import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.fft as fft

from initialize import Result


class AcquisitionResultTorch(Result):
    """C/A-code acquisition computed on a PyTorch device (SoftGNSS acquisition.m port).

    The coarse search correlates two consecutive 1 ms blocks against every
    Doppler bin at once (one batched FFT per block); the carrier frequency of
    each detected satellite is then refined from 10 ms of code-wiped signal.
    After :meth:`acquire`, ``self._results`` is a ``numpy.recarray`` with
    fields ``carrFreq``, ``codePhase`` and ``peakMetric`` (32 entries, entry
    ``i`` belonging to PRN ``i + 1``).
    """

    def __init__(self, settings, device="cuda"):
        """Store the receiver *settings* and the torch *device* to compute on."""
        super().__init__(settings)
        self.device = device
        self._results = None
        self._channels = None

    # ================================================================
    def acquire(self, longSignal):
        """
        GPU/NPU-based FFT acquisition using PyTorch batch processing.

        Equivalent to SoftGNSS acquisition.m but vectorized across Doppler bins.

        Args:
            longSignal: interleaved I/Q values (I0, Q0, I1, Q1, ...).  The
                coarse search needs the first 2 ms; the fine frequency search
                of a detected satellite needs ``codePhase + 10 ms`` of complex
                samples (SoftGNSS reads 11 ms).

        Side effects:
            Sets ``self._results`` to a recarray with fields ``carrFreq``
            (Hz), ``codePhase`` (samples) and ``peakMetric`` (1st/2nd peak
            ratio), and prints a progress line listing the detected PRNs.

        Raises:
            ValueError: if ``longSignal`` is too short for the coarse search
                or for the fine search of a detected satellite.
        """
        settings = self._settings
        device = self.device
        samplesPerCode = int(settings.samplesPerCode)

        if len(longSignal) < 4 * samplesPerCode:
            raise ValueError(
                "longSignal must hold at least 2 ms of interleaved I/Q values "
                f"({4 * samplesPerCode}), got {len(longSignal)}"
            )

        # --- Two consecutive 1 ms blocks (IQ interleaved) ------------------
        signal1 = self._to_device_complex(longSignal[0:2 * samplesPerCode])
        signal2 = self._to_device_complex(longSignal[2 * samplesPerCode:4 * samplesPerCode])

        # DC removal (long segment, used by the fine frequency search)
        longIQ = self._to_device_complex(longSignal)
        signal0DC = longIQ - torch.mean(longIQ)

        # --- Frequency bins (0.5 kHz step across the search band) ---------
        numberOfFrqBins = int(round(settings.acqSearchBand * 2) + 1)
        frqBins = np.linspace(
            settings.IF - settings.acqSearchBand / 2 * 1e3,
            settings.IF + settings.acqSearchBand / 2 * 1e3,
            numberOfFrqBins,
        )

        # --- Carrier wipe-off, identical for every PRN: do it once --------
        # Phases reach ~1e4 rad, so build the carrier in float64 and only
        # then narrow to complex64. exp(-j*w*t) moves a complex carrier at
        # +w down to DC, matching the signed frequency the fine search reports.
        ts = 1.0 / settings.samplingFreq
        phasePoints = np.arange(samplesPerCode) * 2 * np.pi * ts
        carr = torch.tensor(
            np.exp(-1j * np.outer(frqBins, phasePoints)),
            dtype=torch.complex64,
            device=device,
        )  # (numberOfFrqBins, samplesPerCode)
        IQfreqDom1 = fft.fft(carr * signal1, dim=1)
        IQfreqDom2 = fft.fft(carr * signal2, dim=1)
        del carr  # only the spectra are needed from here on

        caCodesTable = settings.makeCaTable()
        samplesPerCodeChip = int(round(settings.samplingFreq / settings.codeFreqBasis))

        carrFreq = np.zeros(32)
        codePhase = np.zeros(32)
        peakMetric = np.zeros(32)

        print("(")
        # NOTE(review): acqSatelliteList is taken to hold 1-based PRN numbers
        # (SoftGNSS convention, default 1..32); PRN below is the 0-based index
        # used for caCodesTable, generateCAcode and the result arrays.
        for prnNumber in settings.acqSatelliteList:
            PRN = int(prnNumber) - 1

            # --- Circular correlation with this PRN's sampled C/A code -----
            caCode = torch.tensor(caCodesTable[PRN, :], dtype=torch.float32, device=device)
            caCodeFreqDom = fft.fft(caCode).conj()
            acqRes1 = torch.abs(fft.ifft(IQfreqDom1 * caCodeFreqDom, dim=1)) ** 2
            acqRes2 = torch.abs(fft.ifft(IQfreqDom2 * caCodeFreqDom, dim=1)) ** 2

            # --- Per Doppler bin keep the stronger millisecond -------------
            # (a whole row, as in SoftGNSS; an element-wise max would blend
            # the noise of both blocks and inflate the second peak)
            use1 = acqRes1.amax(dim=1) > acqRes2.amax(dim=1)
            results = torch.where(use1[:, None], acqRes1, acqRes2)

            # --- Peak detection ---------------------------------------------
            frequencyBinIndex, codePhaseIdx = divmod(int(torch.argmax(results)), results.shape[1])
            peakRow = results[frequencyBinIndex]
            peakSize = float(peakRow[codePhaseIdx])
            secondPeakSize = self._second_peak(peakRow, codePhaseIdx, samplesPerCodeChip)

            ratio = peakSize / (secondPeakSize + 1e-12)
            peakMetric[PRN] = ratio

            if settings.acqMinThreshold < ratio < settings.acqMaxThreshold:
                print(f"{PRN + 1:02d} ", end="")
                carrFreq[PRN] = self._fine_doppler(
                    signal0DC, settings.generateCAcode(PRN), codePhaseIdx
                )
                codePhase[PRN] = codePhaseIdx
                # To inspect the search grid, pass results.cpu().numpy() and
                # frqBins (rows of results follow frqBins) to plot_acquisition_3d.
            else:
                print(". ", end="")
        print(")\n")

        # np.rec is the public home of the record helpers (np.core is
        # deprecated since NumPy 2.0).
        self._results = np.rec.fromarrays(
            [carrFreq, codePhase, peakMetric],
            names="carrFreq,codePhase,peakMetric",
        )
        return

    # ================================================================
    def _to_device_complex(self, interleaved):
        """Turn interleaved I/Q values into a complex64 tensor on ``self.device``.

        A trailing unpaired value (odd length) is ignored.
        """
        pairs = len(interleaved) // 2
        iq = interleaved[0:2 * pairs:2] + 1j * interleaved[1:2 * pairs:2]
        return torch.tensor(iq, dtype=torch.complex64, device=self.device)

    @staticmethod
    def _second_peak(peakRow, codePhaseIdx, samplesPerCodeChip):
        """Return the largest value of *peakRow* at least one chip away from the peak.

        Distance is measured circularly (the correlation is circular), so
        samples within ``samplesPerCodeChip - 1`` of ``codePhaseIdx`` on either
        side, wrapping around the row ends, are excluded.
        """
        n = peakRow.shape[0]
        offsets = (torch.arange(n, device=peakRow.device) - codePhaseIdx) % n
        distance = torch.minimum(offsets, n - offsets)
        outside = peakRow[distance >= samplesPerCodeChip]
        if outside.numel() == 0:
            # Exclusion window covers the whole row; nothing to compare with.
            return 0.0
        return float(outside.max())

    def _fine_doppler(self, signal0DC, caCodeChips, codePhaseIdx):
        """Estimate the carrier frequency (Hz) from 10 ms of code-wiped signal.

        Args:
            signal0DC: DC-free complex signal tensor.
            caCodeChips: chip sequence from ``settings.generateCAcode`` (1023 chips).
            codePhaseIdx: coarse code phase in samples.

        Returns:
            Signed frequency of the strongest FFT bin, in Hz.

        Raises:
            ValueError: if fewer than ``codePhaseIdx + 10 ms`` samples are available.
        """
        settings = self._settings
        numSamples = 10 * int(settings.samplesPerCode)
        end = codePhaseIdx + numSamples
        if end > signal0DC.shape[0]:
            raise ValueError(
                f"fine frequency search needs {end} complex samples, "
                f"only {signal0DC.shape[0]} available"
            )

        # Chip index per sample, computed in float64 so floor() does not slip
        # at chip boundaries over 10 ms (float32 would).
        ts = 1.0 / settings.samplingFreq
        codeValueIndex = np.floor(
            ts * np.arange(1, numSamples + 1) / (1.0 / settings.codeFreqBasis)
        ).astype(np.int64) % 1023
        longCaCode = torch.tensor(
            np.asarray(caCodeChips)[codeValueIndex],
            dtype=torch.float32,
            device=signal0DC.device,
        )

        xCarrier = signal0DC[codePhaseIdx:end] * longCaCode
        fftNumPts = int(8 * 2 ** np.ceil(np.log2(numSamples)))  # next pow2, 8x zero-pad
        fftxc = torch.abs(fft.fft(xCarrier, n=fftNumPts))
        fftMaxIdx = int(torch.argmax(fftxc))

        # Bin frequency as a Python float: a float32 bin grid loses several Hz
        # near samplingFreq, comparable to the bin spacing.
        dopFreq = fftMaxIdx * settings.samplingFreq / fftNumPts
        if fftMaxIdx > (fftNumPts + 1) / 2.0:
            dopFreq -= settings.samplingFreq  # upper half of the FFT = negative frequencies
        return dopFreq

    # ================================================================
    def plot(self):
        """Bar plot of acquisition metric for all PRNs.

        Raises:
            RuntimeError: if :meth:`acquire` has not produced results yet.
        """
        if not isinstance(self._results, np.recarray):
            raise RuntimeError("acquire() must be called before plot()")
        peakMetric = self._results.peakMetric
        plt.figure()
        plt.bar(np.arange(1, len(peakMetric) + 1), peakMetric)
        plt.title("Acquisition Results (GPU-PyTorch)")
        plt.xlabel("PRN Number")
        plt.ylabel("Peak Metric (1st/2nd)")
        plt.grid(True)
        plt.show()

    # ================================================================
    @staticmethod
    def plot_acquisition_3d(results, frqBins, PRN, settings):
        """
        Plot 3D surface acquisition result for one PRN.

        results: matrix [numberOfFrqBins x samplesPerCode]
        frqBins: list of Doppler frequencies (Hz)
        PRN: 0-based PRN index (the title shows PRN + 1)
        settings: unused; kept for call compatibility
        """
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old matplotlib)

        codePhases = np.arange(results.shape[1])
        dopplers = np.array(frqBins)
        X, Y = np.meshgrid(codePhases, dopplers)
        Z = results

        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111, projection='3d')
        surf = ax.plot_surface(X, Y / 1000.0, Z, cmap='viridis', linewidth=0, antialiased=False)
        fig.colorbar(surf, shrink=0.5, aspect=5)

        ax.set_title(f'Acquisition Surface (PRN {PRN + 1})')
        ax.set_xlabel('Code Phase (samples)')
        ax.set_ylabel('Doppler (kHz)')
        ax.set_zlabel('Correlation Power')
        plt.show()