TypeError: can't convert np.ndarray of type numpy.str_. The only supported types are: float64, float32

I am working on an RNN algorithm to predict a user's next location. I train it using PyTorch, but I get an error.

Here is the traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)

--> 203         total_loss += run(batch_user, batch_td, batch_ld, batch_loc, batch_dst, step=1)


<ipython-input-34-3a623cd33ef9> in run(user, td, ld, loc, dst, step)

--> 159     user = Variable(torch.from_numpy(np.asarray([user],dtype='<U32'))).type(ltype)

TypeError: can't convert np.ndarray of type numpy.str_. The only supported types are: float64, float32, float16, int64, int32, int16, int8, uint8, and bool.
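You can reproduce the failure without PyTorch, because the problem already exists in the array that np.asarray(..., dtype='<U32') builds. '<U32' is NumPy's fixed-width unicode string dtype, so even a numeric id becomes a numpy.str_ element. The sketch below checks an array's scalar type against the types listed in the error message. from_numpy_would_accept is a hypothetical helper, not a PyTorch function. Its type list is copied from the message above, not from the PyTorch documentation, and other PyTorch versions may accept more types.

```python
import numpy as np

# Scalar types named in the TypeError message as accepted by torch.from_numpy
# (copied from the message; other PyTorch versions may accept more).
SUPPORTED_TYPES = (
    np.float64, np.float32, np.float16,
    np.int64, np.int32, np.int16, np.int8,
    np.uint8, np.bool_,
)


def from_numpy_would_accept(arr: np.ndarray) -> bool:
    """Return True if arr's scalar type is one of the types listed in the error."""
    return arr.dtype.type in SUPPORTED_TYPES


# Mirrors the failing line: dtype='<U32' turns the number into a unicode string.
user = 42
as_string = np.asarray([user], dtype='<U32')
print(as_string, as_string.dtype.type)      # ['42'] <class 'numpy.str_'>
print(from_numpy_would_accept(as_string))   # False

# The same value with an integer dtype is one of the accepted types.
as_int = np.asarray([user], dtype=np.int64)
print(from_numpy_would_accept(as_int))      # True
```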

My code:

###############################################################################################
def run(user, td, ld, loc, dst, step):

    optimizer.zero_grad()

    seqlen = len(td)

    user = Variable(torch.from_numpy(np.asarray([user],dtype='<U32'))).type(ltype)

    #neg_loc = Variable(torch.FloatTensor(1).uniform_(0, len(poi2pos)-1).long()).type(ltype)
    #(neg_lati, neg_longi) = poi2pos.get(neg_loc.data.cpu().numpy()[0])
    rnn_output = h_0
    for idx in xrange(seqlen-1):
        td_upper = Variable(torch.from_numpy(np.asarray(up_time-td[idx],dtype='<U32'))).type(ftype)
        td_lower = Variable(torch.from_numpy(np.asarray(td[idx]-lw_time,dtype='<U32'))).type(ftype)
        ld_upper = Variable(torch.from_numpy(np.asarray(up_dist-ld[idx],dtype='<U32'))).type(ftype)
        ld_lower = Variable(torch.from_numpy(np.asarray(ld[idx]-lw_dist,dtype='<U32'))).type(ftype)
        location = Variable(torch.from_numpy(np.asarray(loc[idx],dtype='<U32'))).type(ltype)
        rnn_output = strnn_model(td_upper, td_lower, ld_upper, ld_lower, location, rnn_output)#, neg_lati, neg_longi, neg_loc, step)
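Every np.asarray call in run() passes dtype='<U32'. That dtype converts the value to a string even when it is numeric, as with up_time - td[idx]. Dropping '<U32' in favour of a numeric dtype avoids the error. The sketch below assumes that the ids are integers and that the time and distance values are plain numbers. up_time, lw_time, td, user and loc are stand-in values here, and as_index and as_feature are hypothetical helper names. The torch lines are left as comments so the sketch runs with NumPy alone. In recent PyTorch versions, Variable has been merged into Tensor, so the wrapper is no longer required.

```python
import numpy as np

# Hypothetical stand-ins for values from the question's training loop.
up_time, lw_time = 100.0, 0.0
td = np.array([12.5, 30.0, 47.25])   # time differences (assumed numeric)
user, loc = 7, [3, 9, 1]             # user and location ids (assumed integers)


def as_index(value):
    # Ids used as indices: int64 is the dtype torch.from_numpy maps to torch.long.
    return np.asarray(value, dtype=np.int64)


def as_feature(value):
    # Continuous inputs such as time/distance gaps: float32 maps to torch.float32.
    return np.asarray(value, dtype=np.float32)


user_arr = as_index([user])          # same shape as np.asarray([user], ...)
td_upper = as_feature(up_time - td[0])
td_lower = as_feature(td[0] - lw_time)
location = as_index(loc[0])
print(user_arr.dtype, td_upper.dtype, location.dtype)  # int64 float32 int64

# With PyTorch installed, the arrays can then be converted, e.g.:
#   user = torch.from_numpy(as_index([user])).type(ltype)
#   td_upper = torch.from_numpy(as_feature(up_time - td[idx])).type(ftype)
```

This change only helps if user and loc already hold numbers. If they hold strings such as "u3", np.int64 conversion raises a ValueError, and the ids need an integer encoding first.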

Maybe user is of type object. If so, try converting it to a numeric type, for example with user.astype("category").cat.codes (assuming user is a pandas Series).
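If the ids really are strings, the encoding step the answer describes can also be done with NumPy alone. The sketch below assumes the ids are strings and that the mapping is built once over the whole dataset before training, since run() receives a single user per call. np.unique with return_inverse=True assigns each id its index in the sorted list of distinct ids. raw_users, user2id and encode_user are hypothetical names.

```python
import numpy as np

# Hypothetical raw ids; in the question these would come from the dataset.
raw_users = np.array(["u3", "u1", "u3", "u2"])

# np.unique returns the sorted distinct ids and, with return_inverse=True,
# each element's position in that sorted list, i.e. an integer code per row.
categories, codes = np.unique(raw_users, return_inverse=True)
codes = codes.astype(np.int64)

# Lookup table for encoding one id at a time, as run() receives a single user.
user2id = {u: i for i, u in enumerate(categories.tolist())}


def encode_user(user_id):
    """Return a 1-element int64 array that torch.from_numpy can convert."""
    return np.asarray([user2id[user_id]], dtype=np.int64)


print(categories.tolist())  # ['u1', 'u2', 'u3']
print(codes.tolist())       # [2, 0, 2, 1]
print(encode_user("u3"))    # [2]
```

A dictionary lookup inside run() avoids re-encoding the whole column on every call. An id that was not seen when user2id was built will raise a KeyError.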
