[英]convert sess.run to pytorch
我正在尝试将代码从 tf 转换为 pytorch。 我被卡住的代码部分是这个 sess.run。 据我所知, pytorch 不需要它,但我找不到复制它的方法。 我附上代码。
特遣部队:
ebnos_db = np.linspace(1,6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
for i in range(ebnos_db.shape[0]):
ebno_db = ebnos_db[i]
bers_no_training[i] += sess.run(ber, feed_dict={
batch_size: samples,
noise_var: ebnodb2noisevar(ebno_db, coderate)
})
bers_no_training /= epochs
samples 是一个 int32 并且 ebnodb2noisevar() 返回一个 float32。
TF 中的 BER 计算如下:
ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))
在 PT 中:
wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)
我认为BER计算得很好,但主要问题是我不知道如何将sess.run转换为PyTorch,也不完全了解它的function。
有谁能够帮我?
谢谢
您可以在 PyTorch 中执行相同操作,但在ber
方面更容易:
ber = torch.mean((x != x_hat).float())
就足够了。
是的,PyTorch 不需要它,因为它基于动态图构造(与 Tensorflow 不同,它的 static 方法)。
在tensorflow
中, sess.run
用于将值输入到创建的图形中; 这里名为batch_size
的tf.Placeholder
(图中的变量代表用户可以“注入”他的数据的节点)将被提供samples
和带有noise_var
ebnodb2noisevar(ebno_db, coderate)
。
将其转换为 PyTorch 通常很简单,因为您不需要 session 的任何类似图形的方法。 只需使用带有正确输入(如samples
和noise_var
)的神经网络(或类似),就可以了。 您必须检查您的图表(因此ber
是如何从batch_size
和noise_var
的)并在 PyTorch 中重新实现它。
另外,在深入研究之前,请查看PyTorch 介绍性教程以了解该框架。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.