SoftGNSS/MatrixFFT4096.py

50 lines
1.7 KiB
Python
Raw Normal View History

2025-10-22 16:08:12 +07:00
import torch
import numpy as np
class MatrixFFT4096:
    """4096-point DFT / inverse DFT computed as dense matrix products.

    The forward matrix ``DFT[k, n] = exp(-2j*pi*k*n/N)`` and its inverse
    ``IDFT = conj(DFT) / N`` are precomputed once in complex64 on ``device``.
    Each ``fft``/``ifft`` call is then a single ``torch.matmul``.
    """

    def __init__(self, device="cpu"):
        self.N = 4096
        self.device = torch.device(device)
        # Always build the matrices in complex64 to keep them light
        # (each (4096, 4096) complex64 matrix is 128 MiB).
        W = self._dft_matrix(self.N)
        # W is symmetric and W @ conj(W) == N * I, so the inverse DFT matrix is
        # conj(W) / N -- no need to evaluate exp() a second time.
        W_inv = W.conj() / self.N
        # from_numpy avoids an extra full-size copy; .to() copies only when the
        # target device is not the CPU.
        self.DFT = torch.from_numpy(W).to(self.device)
        self.IDFT = torch.from_numpy(W_inv).to(self.device)

    @staticmethod
    def _dft_matrix(n):
        """Return the (n, n) forward DFT matrix exp(-2j*pi*k*m/n) as complex64."""
        idx = np.arange(n, dtype=np.int64)
        # exp(-2j*pi*k*m/n) depends only on (k*m) mod n, so the entries are
        # looked up in a length-n twiddle table. This avoids evaluating exp()
        # on n*n complex128 values and keeps phase arguments small.
        twiddle = np.exp(-2j * np.pi * idx / n).astype(np.complex64)
        return twiddle[np.outer(idx, idx) % n]

    def _to_tensor(self, x):
        """Convert numpy or torch input to a complex64 torch.Tensor on ``self.device``.

        Real inputs are promoted to complex64; complex128 inputs are narrowed
        to complex64.

        Raises:
            TypeError: if ``x`` is neither a numpy.ndarray nor a torch.Tensor.
        """
        if isinstance(x, np.ndarray):
            # torch cannot build tensors from numpy views with negative strides
            # (e.g. x[::-1] or np.flip(x)); ascontiguousarray copies only when needed.
            return torch.tensor(
                np.ascontiguousarray(x), dtype=torch.complex64, device=self.device
            )
        if isinstance(x, torch.Tensor):
            return x.to(device=self.device, dtype=torch.complex64)
        raise TypeError("Input must be numpy.ndarray or torch.Tensor")

    def fft(self, x):
        """Forward DFT by matrix multiplication.

        Args:
            x: numpy.ndarray or torch.Tensor whose last dimension has length N,
               e.g. shape (B, 4096) or (4096,).

        Returns:
            complex64 torch.Tensor on ``self.device``; its last dimension holds
            the frequency bins.
        """
        x = self._to_tensor(x)
        # X[..., k] = sum_n x[..., n] * DFT[k, n]  ==  x @ DFT.T
        return torch.matmul(x, self.DFT.T)

    def ifft(self, X):
        """Inverse DFT (including the 1/N scale) by matrix multiplication.

        Args:
            X: numpy.ndarray or torch.Tensor whose last dimension has length N.

        Returns:
            complex64 torch.Tensor on ``self.device``.
        """
        X = self._to_tensor(X)
        return torch.matmul(X, self.IDFT.T)