[英]Neural Network Loss Inaccurate
I trained a model for neural network to find the roots of the quadratic equation (discriminant >= 0) But when I checked the same on my example, even the loss is small it is showing a far from exact answer. 我为神经网络训练了一个模型,以找到二次方程的根(判别式> = 0),但是当我在示例中进行检查时,即使损失很小,也显示出远非确切的答案。
Loss graph: 损耗图:
My example: 我的例子:
a = 1 b = -2 c = -24 model.predict(np.array([[a/max,b/max,c/max]])) * max Out[421]: array([[-15.218947 , -1.3733944]], dtype=float32) #but should be 6 and -4
Look here please: 请看这里:
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.utils import np_utils
from keras.layers import Dropout
x_in = np.array([]).reshape(0,3)
x_answer = np.array([]).reshape(0,2)
for i in range(300):
a = np.random.randint(-1000,1000)
b = np.random.randint(-1000,1000)
c = np.random.randint(-1000,1000)
D = np.power(b,2)-4*a*c
if(a != 0):
if(D >= 0):
x1 = (-b+np.sqrt(D))/(2*a)
x2 = (-b-np.sqrt(D))/(2*a)
x_in = np.concatenate((x_in,[[a,b,c]]))
x_answer = np.concatenate((x_answer,[[x1,x2]]))
np.random.seed()
NB_EPOCH = 300
VERBOSE = 1
x_in = np.asarray(x_in, dtype=np.float32)
x_answer = np.asarray(x_answer, dtype=np.float32)
min_in = np.nanmin(x_in)
min_answ = np.nanmin(x_answer)
min = -1000 #np.min(np.array([min_in,min_answ]))
max_in = np.nanmax(x_in)
max_answ = np.nanmax(x_answer)
max = 1000 #np.max(np.array([max_in,max_answ]))
x_in /= max
x_answer /= max
model = Sequential()
model.add(Dense(30, input_dim = 3, activation='relu'))
#model.add(Dropout(0.2))
model.add(Dense(40, activation='softmax'))
#model.add(Dropout(0.2))
model.add(Dense(50, activation='linear'))
model.add(Dense(2))
model.compile(loss='mse', optimizer='adam')
history = model.fit(x_in, x_answer, epochs=NB_EPOCH, verbose=VERBOSE)
UPDATE: 更新:
what to do? 该怎么办?
I think that 300 training points is too few for a parameter space of (2000)**3 possible parameter values of a, b and c. 我认为对于(2000)** 3个可能的a,b和c参数值的参数空间,300个训练点太少了。 You could try to give it more training data.
您可以尝试为其提供更多培训数据。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.