繁体   English   中英

T.hessian在theano中给出NotImplementedError()

[英]T.hessian gives NotImplementedError() in theano

它如何运作

    g_W = T.grad(cost=cost, wrt=classifier.vparamW) 

而这

    H_W=T.hessian(cost=cost, wrt=classifier.vparamW)

给出NotImplementedError()可能是这种成本函数中的问题:

    -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]) 

这里y是从0到n-1的类标签的向量,

    self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b) 

我无法使用提供的有限代码来重现此问题。 但是,这是T.gradT.hessian的完整工作演示。

import numpy
import theano
import theano.tensor as T

x = T.matrix()
w_flat = theano.shared(numpy.random.randn(3, 2).astype(theano.config.floatX).flatten())
w = w_flat.reshape((3, 2))
cost = T.pow(theano.dot(x, w), 2).sum()
g_w = T.grad(cost=cost, wrt=[w])
h_w = T.hessian(cost=cost, wrt=[w_flat])
f = theano.function([x], outputs=g_w + h_w)
for output in f(numpy.random.randn(4, 3).astype(theano.config.floatX)):
    print output.shape, '\n', output

请注意, T.hessianwrt值必须是一个向量。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM