繁体   English   中英

将 sess.run 转换为 pytorch

[英]convert sess.run to pytorch

我正在尝试将代码从 tf 转换为 pytorch。 我被卡住的代码部分是这个 sess.run。 据我所知, pytorch 不需要它,但我找不到复制它的方法。 我附上代码。

特遣部队:

ebnos_db = np.linspace(1,6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
    for i in range(ebnos_db.shape[0]):
        ebno_db = ebnos_db[i]
        bers_no_training[i] += sess.run(ber, feed_dict={
            batch_size: samples,
            noise_var: ebnodb2noisevar(ebno_db, coderate)
        })
bers_no_training /= epochs

samples 是一个 int32 并且 ebnodb2noisevar() 返回一个 float32。

TF 中的 BER 计算如下:

ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))

在 PT 中:

wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)

我认为BER计算得很好,但主要问题是我不知道如何将sess.run转换为PyTorch,也不完全了解它的function。

有谁能够帮我?

谢谢

您可以在 PyTorch 中执行相同操作,但在ber方面更容易:

ber = torch.mean((x != x_hat).float())

就足够了。

是的,PyTorch 不需要它,因为它基于动态图构造(与 Tensorflow 不同,它的 static 方法)。

tensorflow中, sess.run用于将值输入到创建的图形中; 这里名为batch_sizetf.Placeholder (图中的变量代表用户可以“注入”他的数据的节点)将被提供samples和带有noise_var ebnodb2noisevar(ebno_db, coderate)

将其转换为 PyTorch 通常很简单,因为您不需要 session 的任何类似图形的方法。 只需使用带有正确输入(如samplesnoise_var )的神经网络(或类似),就可以了。 您必须检查您的图表(因此ber是如何从batch_sizenoise_var的)并在 PyTorch 中重新实现它。

另外,在深入研究之前,请查看PyTorch 介绍性教程以了解该框架。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM