
Wavelet 2D Scattering transform of an input image

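I am trying to apply a 2D scattering transform to an input image. When I run the following code, I get this error: "The filters are not compatible for multiplication!" Can anyone help? Thanks!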

import torch
from kymatio import Scattering2D
import numpy as np
import PIL
from PIL import Image

FILENAME = "add a png file path"
image = PIL.Image.open(FILENAME).convert("L")

a = np.array(image).astype(np.float32)
x = torch.from_numpy(a)
imageSize = x.shape

scattering = Scattering2D(J=2, shape=imageSize, L=8)
Sx = scattering.forward(x)

print(Sx.size()) 
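The answer below reports that a small square image works. One quick experiment is therefore to crop the input to a centered square before building the transform. This is a minimal sketch, not a confirmed fix: the helper name center_crop_square is hypothetical, and rounding the side down to a multiple of 2**J is an assumption rather than a documented kymatio requirement.

```python
import numpy as np


def center_crop_square(a, multiple=1):
    """Crop a 2D array to a centered square whose side is a multiple of `multiple`.

    Hypothetical helper for experimentation only.
    """
    side = min(a.shape[0], a.shape[1])
    side -= side % multiple  # round down, e.g. to a multiple of 2**J (assumption)
    top = (a.shape[0] - side) // 2
    left = (a.shape[1] - side) // 2
    return a[top:top + side, left:left + side]


# Stand-in for np.array(image).astype(np.float32) from the question.
a = np.zeros((120, 90), dtype=np.float32)
sq = center_crop_square(a, multiple=2 ** 2)  # J = 2
print(sq.shape)  # (88, 88)

# The cropped array would then be used exactly as in the question:
# x = torch.from_numpy(np.ascontiguousarray(sq))
# scattering = Scattering2D(J=2, shape=sq.shape, L=8)
```

If the error disappears with the cropped input, that supports the idea that the image dimensions are involved; if it persists, the cause lies elsewhere.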

It seems to work fine for a small 1 KB .png with equal width and height (square rather than rectangular):

import torch
from kymatio import Scattering2D
import numpy as np
import PIL
from PIL import Image

FILENAME = "/path/to/dir/small_size_1_KB.png"
image = PIL.Image.open(FILENAME).convert("L")

a = np.array(image).astype(np.float64)
x = torch.from_numpy(a)
imageSize = x.shape

scattering = Scattering2D(J=2, shape=imageSize, L=8)

Sx = scattering.forward(x)

print(Sx.size())

Output

torch.Size([81, 19, 19])
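The 81 in this output shape can be checked by counting scattering paths. A minimal sketch, assuming kymatio's default max_order=2: order 0 contributes one low-pass channel, order 1 contributes J * L channels (one per scale and orientation), and order 2 contributes L * L * J * (J - 1) / 2 channels (pairs of scales j1 < j2, each with L orientations). The function name num_scattering_channels is hypothetical. The trailing 19 x 19 is the spatial size after downsampling by roughly 2**J = 4.

```python
def num_scattering_channels(J, L, max_order=2):
    """Count 2D scattering coefficients per input channel (hypothetical helper)."""
    count = 1  # order 0: a single low-pass average
    if max_order >= 1:
        count += J * L  # order 1: every (scale, orientation) pair
    if max_order >= 2:
        # order 2: scale pairs j1 < j2, each with L * L orientation pairs
        count += L * L * J * (J - 1) // 2
    return count


print(num_scattering_channels(J=2, L=8))  # 81
```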

The error you are getting is raised in this method (backend_torch.py), so it should be related to the tensor sizes:

def cdgmm(A, B, inplace=False):
    """
        Complex pointwise multiplication between (batched) tensor A and tensor B.

        Parameters
        ----------
        A : tensor
            input tensor with size (B, C, M, N, 2)
        B : tensor
            B is a complex tensor of size (M, N, 2)
        inplace : boolean, optional
            if set to True, all the operations are performed inplace

        Returns
        -------
        C : tensor
            output tensor of size (B, C, M, N, 2) such that:
            C[b, c, m, n, :] = A[b, c, m, n, :] * B[m, n, :]
    """
    A, B = A.contiguous(), B.contiguous()
    if A.size()[-3:] != B.size():
        raise RuntimeError('The filters are not compatible for multiplication!')

    if not iscomplex(A) or not iscomplex(B):
        raise TypeError('The input, filter and output should be complex')

    if B.ndimension() != 3:
        raise RuntimeError('The filters must be simply a complex array!')

    if type(A) is not type(B):
        raise RuntimeError('A and B should be same type!')


    C = A.new(A.size())

    A_r = A[..., 0].contiguous().view(-1, A.size(-2)*A.size(-3))
    A_i = A[..., 1].contiguous().view(-1, A.size(-2)*A.size(-3))

    B_r = B[...,0].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_i)
    B_i = B[..., 1].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_r)

    C[..., 0].view(-1, C.size(-2)*C.size(-3))[:] = A_r * B_r - A_i * B_i
    C[..., 1].view(-1, C.size(-2)*C.size(-3))[:] = A_r * B_i + A_i * B_r

    return C if not inplace else A.copy_(C)
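To make the failing check concrete, here is a minimal NumPy re-implementation of the same complex pointwise product, written only for illustration. The name cdgmm_numpy is hypothetical and it is not part of kymatio or pyscatwave. It keeps the (real, imaginary) last-axis layout from the docstring above and raises the same error whenever the trailing (M, N, 2) dimensions of A differ from the filter's shape, which presumably happens when the input and the filters (built for the shape passed to Scattering2D) end up with different spatial sizes.

```python
import numpy as np


def cdgmm_numpy(A, B):
    """Complex pointwise product of A (..., M, N, 2) with filter B (M, N, 2).

    The last axis stores (real, imaginary) parts. Illustrative sketch only.
    """
    # Same condition as the torch version: trailing (M, N, 2) must match.
    if A.shape[-3:] != B.shape:
        raise RuntimeError('The filters are not compatible for multiplication!')
    A_r, A_i = A[..., 0], A[..., 1]
    B_r, B_i = B[..., 0], B[..., 1]  # broadcast over leading batch/channel dims
    C = np.empty_like(A)
    C[..., 0] = A_r * B_r - A_i * B_i  # real part
    C[..., 1] = A_r * B_i + A_i * B_r  # imaginary part
    return C


rng = np.random.default_rng(0)
A = rng.standard_normal((1, 1, 4, 4, 2))  # (batch, channel, M, N, 2)
B = rng.standard_normal((4, 4, 2))        # filter of matching (M, N, 2)
C = cdgmm_numpy(A, B)
```

Calling cdgmm_numpy(A, np.zeros((5, 5, 2))) raises the RuntimeError, because the 4 x 4 input no longer matches the 5 x 5 filter.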

Source

https://github.com/edouardoyallon/pyscatwave/blob/master/scatwave/utils.py
