50 lines
1.7 KiB
Python
50 lines
1.7 KiB
Python
|
|
import torch
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
|
||
|
|
class MatrixFFT4096:
|
||
|
|
def __init__(self, device="cpu"):
|
||
|
|
self.N = 4096
|
||
|
|
self.device = torch.device(device)
|
||
|
|
|
||
|
|
# Luôn xây ma trận bằng complex64 để nhẹ hơn
|
||
|
|
n = np.arange(self.N)
|
||
|
|
k = n.reshape((self.N, 1))
|
||
|
|
W = np.exp(-2j * np.pi * k * n / self.N).astype(np.complex64)
|
||
|
|
W_inv = np.exp(2j * np.pi * k * n / self.N).astype(np.complex64) / self.N
|
||
|
|
|
||
|
|
self.DFT = torch.tensor(W, dtype=torch.complex64, device=self.device)
|
||
|
|
self.IDFT = torch.tensor(W_inv, dtype=torch.complex64, device=self.device)
|
||
|
|
|
||
|
|
def _to_tensor(self, x):
|
||
|
|
"""Nhận numpy hoặc torch -> ép về torch.complex64 trên device"""
|
||
|
|
if isinstance(x, np.ndarray):
|
||
|
|
# numpy mặc định là complex128 -> ép về complex64
|
||
|
|
x = torch.tensor(x, dtype=torch.complex64, device=self.device)
|
||
|
|
elif isinstance(x, torch.Tensor):
|
||
|
|
# nếu là torch nhưng sai dtype -> ép lại complex64
|
||
|
|
if not torch.is_complex(x):
|
||
|
|
x = x.to(torch.complex64)
|
||
|
|
else:
|
||
|
|
x = x.to(torch.complex64)
|
||
|
|
x = x.to(self.device)
|
||
|
|
else:
|
||
|
|
raise TypeError("Input must be numpy.ndarray or torch.Tensor")
|
||
|
|
return x
|
||
|
|
|
||
|
|
def fft(self, x):
|
||
|
|
"""
|
||
|
|
FFT bằng nhân ma trận.
|
||
|
|
x: numpy.ndarray hoặc torch.Tensor, shape (B,4096) hoặc (4096,)
|
||
|
|
"""
|
||
|
|
x = self._to_tensor(x)
|
||
|
|
return torch.matmul(x, self.DFT.T)
|
||
|
|
|
||
|
|
def ifft(self, X):
|
||
|
|
"""
|
||
|
|
IFFT bằng nhân ma trận.
|
||
|
|
X: numpy.ndarray hoặc torch.Tensor
|
||
|
|
"""
|
||
|
|
X = self._to_tensor(X)
|
||
|
|
return torch.matmul(X, self.IDFT.T)
|