簡體   English   中英

如何解決 NeuroDiffEq 中的錯誤“mat1 和 mat2 形狀不能相乘(1000x1 和 3x512)”?

[英]How to resolve the error 'mat1 and mat2 shapes cannot be multiplied (1000x1 and 3x512)' in NeuroDiffEq?

我是神經網絡的新手,對它們的使用方式有基本的了解。 我正在嘗試使用人工神經網絡(ANN),特別是使用 NeuroDiffEq package 來解決具有邊界條件的球面拉普拉斯方程:u(r=0)=u(r=1)=0 對於所有 theta 和 phi Python。 以下是相同的代碼

import numpy as np
import matplotlib.pyplot as plt
import torch
from neurodiffeq import diff 
from neurodiffeq.networks import FCNN 
from neurodiffeq.conditions import DirichletBVPSpherical
from neurodiffeq.solvers import SolverSpherical
from neurodiffeq.monitors import MonitorSpherical
from neurodiffeq.generators import Generator3D
%matplotlib notebook

laplace = lambda u, r, theta, phi: [
diff(((r**2)*diff(u,r,order=1)), r, order=1)/r**2 + 
diff((np.sin(theta))*diff(u,theta,order=1), theta, order=1)/((r**2)*(np.sin(theta))) +
diff(u,phi,order=2)/(r*np.sin(theta))**2
]

conditions = [
    DirichletBVPSpherical(r_0=0.0,f=0.0,r_1=1.0,g=0.0)
]

nets = [
FCNN(n_input_units=3, n_output_units=1, hidden_units=[512]),
]

monitor=MonitorSpherical(r_min=0.0,r_max=1.0,check_every=10,shape=(10,10,10),r_scale='linear',theta_min=0,theta_max=np.pi,phi_min=0,phi_max=2*np.pi)
monitor_callback = monitor.to_callback()

solver = SolverSpherical(
    pde_system=laplace,
    conditions=conditions,
    r_min=0.0,
    r_max=1.0,
    nets=nets,
    train_generator=Generator3D(grid=(10, 10, 10), xyz_min=(0.0, 0.0, 0.0), xyz_max=(1.0, 1.0, 1.0), method='equally-spaced'),
    valid_generator=Generator3D(grid=(10, 10, 10), xyz_min=(0.0, 0.0, 0.0), xyz_max=(1.0, 1.0, 1.0), method='equally-spaced-noisy'),
)

solver.fit(max_epochs=200, callbacks=[monitor_callback])

solution_neural_net_laplace = solver.get_solution()

我收到以下錯誤

mat1 and mat2 shapes cannot be multiplied (1000x1 and 3x512)

對於解決此錯誤的任何幫助,我將不勝感激。 提前致謝!

問題是mat1的形狀與mat2相乘不正確。 可能您使用的是 10x10x10 = 1000 的網格,因此請嘗試將其設為其他內容,即 8x8x8 = 512,或者您可以嘗試將輸入單位設為 1000,看看是否能解決問題。

也可能是n_input_units = 512n_input_units = 1000 ,與n_hidden_units = [something else] (取決於您在網格中所做的更改)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM