import torch import numpy as np class MatrixFFT4096: def __init__(self, device="cpu"): self.N = 4096 self.device = torch.device(device) # Luôn xây ma trận bằng complex64 để nhẹ hơn n = np.arange(self.N) k = n.reshape((self.N, 1)) W = np.exp(-2j * np.pi * k * n / self.N).astype(np.complex64) W_inv = np.exp(2j * np.pi * k * n / self.N).astype(np.complex64) / self.N self.DFT = torch.tensor(W, dtype=torch.complex64, device=self.device) self.IDFT = torch.tensor(W_inv, dtype=torch.complex64, device=self.device) def _to_tensor(self, x): """Nhận numpy hoặc torch -> ép về torch.complex64 trên device""" if isinstance(x, np.ndarray): # numpy mặc định là complex128 -> ép về complex64 x = torch.tensor(x, dtype=torch.complex64, device=self.device) elif isinstance(x, torch.Tensor): # nếu là torch nhưng sai dtype -> ép lại complex64 if not torch.is_complex(x): x = x.to(torch.complex64) else: x = x.to(torch.complex64) x = x.to(self.device) else: raise TypeError("Input must be numpy.ndarray or torch.Tensor") return x def fft(self, x): """ FFT bằng nhân ma trận. x: numpy.ndarray hoặc torch.Tensor, shape (B,4096) hoặc (4096,) """ x = self._to_tensor(x) return torch.matmul(x, self.DFT.T) def ifft(self, X): """ IFFT bằng nhân ma trận. X: numpy.ndarray hoặc torch.Tensor """ X = self._to_tensor(X) return torch.matmul(X, self.IDFT.T)