简体   繁体   English

将 sess.run 转换为 pytorch

[英]convert sess.run to pytorch

I am trying to convert a code from tf to pytorch.我正在尝试将代码从 tf 转换为 pytorch。 The part of the code where I am stuck is this sess.run.我被卡住的代码部分是这个 sess.run。 As fas as I know, pytorch doesn't need it, but I don't find the way to replicate it.据我所知, pytorch 不需要它,但我找不到复制它的方法。 I attach you the code.我附上代码。

TF:特遣部队:

ebnos_db = np.linspace(1,6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
    for i in range(ebnos_db.shape[0]):
        ebno_db = ebnos_db[i]
        bers_no_training[i] += sess.run(ber, feed_dict={
            batch_size: samples,
            noise_var: ebnodb2noisevar(ebno_db, coderate)
        })
bers_no_training /= epochs

samples is a int32 and ebnodb2noisevar() returns a float32. samples 是一个 int32 并且 ebnodb2noisevar() 返回一个 float32。

BER in TF is calculated as: TF 中的 BER 计算如下:

ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))

and in PT:在 PT 中:

wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)

I think BER is well computed, but the main problem is that I don't know how to convert sess.run into PyTorch, nor I completely understand its function.我认为BER计算得很好,但主要问题是我不知道如何将sess.run转换为PyTorch,也不完全了解它的function。

Can anybody help me?有谁能够帮我?

Thanks谢谢

You can do the same in PyTorch but easier when it comes to ber :您可以在 PyTorch 中执行相同操作,但在ber方面更容易:

ber = torch.mean((x != x_hat).float())

would be enough.就足够了。

Yes, PyTorch doesn't need it as it's based on dynamic graph construction (unlike Tensorflow with it's static approach).是的,PyTorch 不需要它,因为它基于动态图构造(与 Tensorflow 不同,它的 static 方法)。

In tensorflow sess.run is used to feed values into created graph;tensorflow中, sess.run用于将值输入到创建的图形中; here tf.Placeholder (variable in graph which represents a node where a user can "inject" his data) named batch_size will be fed with samples and noise_var with ebnodb2noisevar(ebno_db, coderate) .这里名为batch_sizetf.Placeholder (图中的变量代表用户可以“注入”他的数据的节点)将被提供samples和带有noise_var ebnodb2noisevar(ebno_db, coderate)

Translating this to PyTorch is usually straightforward as you don't need any graph-like approaches with session.将其转换为 PyTorch 通常很简单,因为您不需要 session 的任何类似图形的方法。 Just use your neural network (or a-like) with correct input (like samples and noise_var ) and you are fine.只需使用带有正确输入(如samplesnoise_var )的神经网络(或类似),就可以了。 You have to check your graph (so how ber is constructed from batch_size and noise_var ) and reimplement it in PyTorch.您必须检查您的图表(因此ber是如何从batch_sizenoise_var的)并在 PyTorch 中重新实现它。

Also, please check PyTorch introductory tutorials to get a feel of the framework before diving into it.另外,在深入研究之前,请查看PyTorch 介绍性教程以了解该框架。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM