
Error using grad in Theano

I am following a tutorial that shows how to implement logistic regression using Theano. The last line of the code below, the one that calls grad, gives me an error, and I don't know how to fix it.

from theano import tensor, function, grad
TS = tensor.matrix('training-set')
W = tensor.matrix('weights')
E = tensor.matrix('expected')
O = tensor.dot(TS,W)
def_err = ((E-O)**2).sum()
e = function([W,TS,E],def_err)
grad_err = function([W,TS,E],grad(e,W))

This is the error I am getting:

in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    428             raise AssertionError("cost and known_grads can't both be None.")
    429 
--> 430     if cost is not None and isinstance(cost.type, NullType):
    431         raise ValueError("Can't differentiate a NaN cost."
    432                          "cost is NaN because " +

AttributeError: 'Function' object has no attribute 'type'

In the line grad_err = function([W,TS,E],grad(e,W)) you want to compute the gradient of the error 'def_err' with respect to 'W', but you are passing the compiled function 'e' to grad(..). This will never work: grad expects the symbolic expression itself (here 'def_err'), not a function compiled from it, which is why the traceback reports that a 'Function' object has no attribute 'type'. Also note that TS, W, E, O and so on are symbolic tensor variables. They are general expressions and only take concrete values once you supply inputs to a compiled function.

I would recommend going through the tutorial for logistic regression. If you have just started with Theano, these tutorials will definitely help you get started.

This should work:

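from theano import tensor, function, grad

# Symbolic inputs: training set, weights, and expected outputs.
TS = tensor.matrix('training-set')
W = tensor.matrix('weights')
E = tensor.matrix('expected')
# Symbolic model output and sum-of-squares error.
O = tensor.dot(TS,W)
def_err = ((E-O)**2).sum()
# Compiled function that evaluates the error for concrete inputs.
e = function([W,TS,E],def_err)
# Differentiate the symbolic expression def_err, not the compiled function e.
grad_err = function([W,TS,E],grad(def_err,W))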
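
To sanity-check what grad_err should return, it may help to compare against the closed-form gradient of the same error, d(def_err)/dW = -2 * TS^T (E - TS.W). The sketch below is plain Python and does not use Theano at all. The helper names (matmul, analytic_grad, numeric_grad) and the sample values are assumptions made for illustration, not part of the original question or answer.

```python
# Plain-Python sketch (no Theano): gradient of sum((E - TS.W)**2) w.r.t. W,
# computed in closed form and checked with central finite differences.
# Helper names and sample values are illustrative assumptions.

def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def transpose(A):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*A)]

def squared_error(W, TS, E):
    """Same quantity as def_err: sum of squared residuals E - TS.W."""
    O = matmul(TS, W)
    return sum((e - o) ** 2
               for e_row, o_row in zip(E, O)
               for e, o in zip(e_row, o_row))

def analytic_grad(W, TS, E):
    """Closed form: -2 * TS^T (E - TS.W)."""
    O = matmul(TS, W)
    R = [[e - o for e, o in zip(e_row, o_row)] for e_row, o_row in zip(E, O)]
    return [[-2.0 * v for v in row] for row in matmul(transpose(TS), R)]

def numeric_grad(W, TS, E, h=1e-6):
    """Central finite differences, perturbing one entry of W at a time."""
    G = [[0.0] * len(W[0]) for _ in W]
    for i in range(len(W)):
        for j in range(len(W[0])):
            Wp = [row[:] for row in W]
            Wm = [row[:] for row in W]
            Wp[i][j] += h
            Wm[i][j] -= h
            G[i][j] = (squared_error(Wp, TS, E)
                       - squared_error(Wm, TS, E)) / (2 * h)
    return G

TS = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 samples, 2 features
W = [[0.5], [-1.0]]                        # 2 features, 1 output
E = [[1.0], [0.0], [2.0]]                  # 3 expected outputs

ga = analytic_grad(W, TS, E)
gn = numeric_grad(W, TS, E)
print(ga)
print(gn)
```

For these sample values the closed form gives [[-75.0], [-96.0]], and the finite-difference estimate agrees closely. Calling the compiled grad_err with the same W, TS and E as numeric arrays should produce matching numbers up to floating-point precision, since it differentiates the same expression.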
