
Problem with the functions torch.rfft() and torch.irfft()

I need to run code that was written for an older version of PyTorch. I have version 1.9. Since version 1.8, the old torch.rfft() and torch.irfft() have been replaced by torch.fft.rfft() and torch.fft.irfft(), which work differently. I couldn't figure out how to replace these functions so that the code behaves exactly as it did on the old version:

Old code (PyTorch versions before 1.8):

    fU = torch.rfft( u, 1, onesided=False)
    U = torch.irfft(fU, 1, onesided=False)
    torch.fft(x, x.ndim)

Please help me

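As you mentioned, torch.rfft() and torch.irfft() differ from torch.fft.rfft() and torch.fft.irfft() in their input and output formats: the old functions represent complex numbers as real tensors whose last dimension has size 2 (real and imaginary parts), while the new ones use complex dtypes. (See https://github.com/pytorch/pytorch/wiki/The-torch.fft-module-in-PyTorch-1.7 )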
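Replacements for the three calls in the question can be sketched as follows. The helper names (rfft_full_old, irfft_full_old, fft_old) are hypothetical; the sketch assumes the old default normalized=False and that the surrounding code still expects the old real layout with a trailing dimension of size 2, so torch.view_as_real and torch.view_as_complex bridge the two formats. irfft_full_old keeps only the first n // 2 + 1 bins, which is exact when the spectrum is Hermitian (as it is when it comes from rfft_full_old). My understanding is that the old onesided=False path also ignored the redundant half, but compare against an old install if the spectrum is modified in between and may no longer be Hermitian. Also note that the old torch.fft() required at least signal_ndim + 1 input dimensions, so it is worth checking what x.ndim evaluates to in your code.

```python
import torch


def rfft_full_old(u: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for torch.rfft(u, 1, onesided=False):
    # full two-sided spectrum of the last dim, in the old layout (..., n, 2).
    return torch.view_as_real(torch.fft.fft(u, dim=-1))


def irfft_full_old(fU: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for torch.irfft(fU, 1, onesided=False).
    # fU uses the old layout (..., n, 2); only the first n // 2 + 1 bins are
    # passed on, since a real signal's spectrum is Hermitian-symmetric.
    n = fU.shape[-2]
    spectrum = torch.view_as_complex(fU.contiguous())
    return torch.fft.irfft(spectrum[..., : n // 2 + 1], n=n, dim=-1)


def fft_old(x: torch.Tensor, signal_ndim: int) -> torch.Tensor:
    # Hypothetical stand-in for torch.fft(x, signal_ndim).
    # x uses the old layout (..., 2); the transform covers the last
    # signal_ndim dims that precede the trailing (real, imag) pair.
    z = torch.view_as_complex(x.contiguous())
    dims = tuple(range(-signal_ndim, 0))
    return torch.view_as_real(torch.fft.fftn(z, dim=dims))


u = torch.randn(4, 16, dtype=torch.float64)
fU = rfft_full_old(u)   # shape (4, 16, 2)
U = irfft_full_old(fU)  # shape (4, 16); matches u up to rounding
```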

Following an issue on PyTorch's GitHub page, I made the following changes for the two-dimensional case:

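    import torch

    spectrum = torch.fft.rfft2(signal)  # already complex; no torch.complex() step needed
    # torch.complex(re, im) only applies to an old-style real tensor with a trailing
    # dim of size 2 (hypothetical name): torch.complex(old[..., 0], old[..., 1])
    signal_recovered = torch.fft.irfft2(spectrum, s=signal.shape[-2:])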
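A minimal 2D round trip under the new API, assuming a real float64 batch with an odd last dimension (the shapes are arbitrary choices for illustration). Passing s= to torch.fft.irfft2 matters: without it, the output length along the last dimension defaults to 2 * (m - 1) for m input bins, which is always even, so an odd-width input is not restored.

```python
import torch

# Batch of two 6x7 real "images"; the width 7 is deliberately odd.
signal = torch.randn(2, 6, 7, dtype=torch.float64)

# rfft2 keeps 7 // 2 + 1 = 4 bins along the last dim: complex, shape (2, 6, 4).
spectrum = torch.fft.rfft2(signal)

# Passing s= restores the original spatial size exactly.
recovered = torch.fft.irfft2(spectrum, s=signal.shape[-2:])

# Without s=, the last dim defaults to 2 * (4 - 1) = 6 instead of 7.
wrong_width = torch.fft.irfft2(spectrum)
```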
