
Convert sess.run to PyTorch

I am trying to convert some code from TensorFlow to PyTorch. The part where I am stuck is this sess.run call. As far as I know, PyTorch doesn't need it, but I can't find a way to replicate it. The code is below.

TF:

ebnos_db = np.linspace(1, 6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
    for i in range(ebnos_db.shape[0]):
        ebno_db = ebnos_db[i]
        bers_no_training[i] += sess.run(ber, feed_dict={
            batch_size: samples,
            noise_var: ebnodb2noisevar(ebno_db, coderate)
        })
bers_no_training /= epochs

samples is an int32 and ebnodb2noisevar() returns a float32.

BER in TF is calculated as:

ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))

and in PT:

wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)

I think the BER is computed correctly, but the main problem is that I don't know how to convert sess.run into PyTorch, nor do I completely understand what it does.

Can anybody help me?

Thanks

You can do the same in PyTorch, and more simply when it comes to ber:

ber = torch.mean((x != x_hat).float())

would be enough.
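As a quick sanity check, here is a minimal sketch comparing the expression from the question with the shorter one. The tensors x and x_hat are random bit tensors made up for illustration, not the asker's data.

```python
import torch

torch.manual_seed(0)

# Hypothetical bit tensors standing in for the transmitted and decoded bits.
x = torch.randint(0, 2, (4, 8))
x_hat = torch.randint(0, 2, (4, 8))

# Formulation from the question: flip the "equal" mask by hand.
wrong_bits = (torch.eq(x, x_hat).type(torch.float32) * -1) + 1
ber_question = torch.mean(wrong_bits)

# Shorter formulation: compare with != and average the mask.
ber_short = torch.mean((x != x_hat).float())

# torch.ne is the functional form of !=, mirroring tf.not_equal.
ber_ne = torch.ne(x, x_hat).float().mean()

print(ber_question.item(), ber_short.item(), ber_ne.item())
```

All three give the same number, so the version in the question is correct, just more verbose.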

Yes, PyTorch doesn't need it, because it builds the computation graph dynamically as the code runs (unlike TensorFlow 1.x with its static graph approach).

In TensorFlow 1.x, sess.run executes the previously built graph: it evaluates the requested tensor (here ber) and returns its value, with feed_dict supplying the inputs. Here the tf.placeholder (a node in the graph where the user can "inject" data) named batch_size is fed with samples, and noise_var is fed with ebnodb2noisevar(ebno_db, coderate).

Translating this to PyTorch is usually straightforward, because you don't need a session or a separately built graph. Just call your neural network (or a similar function) with the correct inputs (like samples and noise_var) and use the result. You have to check your graph (that is, how ber is constructed from batch_size and noise_var) and reimplement it in PyTorch.
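To make this concrete, here is a rough sketch of how the loop from the question could look in PyTorch. The question does not show how ber is built from batch_size and noise_var, so compute_ber below is a hypothetical stand-in (uncoded BPSK over an AWGN channel with hard decisions), and this ebnodb2noisevar is an assumed conversion, not necessarily the asker's. The values of epochs, samples and coderate are illustrative only.

```python
import numpy as np
import torch

torch.manual_seed(0)


def ebnodb2noisevar(ebno_db, coderate):
    # Assumed conversion for BPSK with unit symbol energy:
    # noise variance = 1 / (2 * coderate * Eb/N0).
    ebno = 10.0 ** (float(ebno_db) / 10.0)
    return 1.0 / (2.0 * coderate * ebno)


def compute_ber(batch_size, noise_var, n=16):
    # Hypothetical replacement for the TF graph that produced `ber`.
    # The former placeholders are now ordinary function arguments.
    x = torch.randint(0, 2, (batch_size, n))          # random bits
    s = 1.0 - 2.0 * x.float()                         # bit 0 -> +1, bit 1 -> -1
    sigma = float(noise_var) ** 0.5
    y = s + sigma * torch.randn_like(s)               # AWGN channel
    x_hat = (y < 0).long()                            # hard decision
    return torch.mean((x != x_hat).float())


epochs, samples, coderate = 5, 1000, 0.5              # illustrative values
ebnos_db = np.linspace(1, 6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])

with torch.no_grad():                                 # evaluation only
    for j in range(epochs):
        for i in range(ebnos_db.shape[0]):
            noise_var = ebnodb2noisevar(ebnos_db[i], coderate)
            # A plain function call takes the place of sess.run(ber, feed_dict=...);
            # .item() turns the 0-dim tensor into a Python float.
            bers_no_training[i] += compute_ber(samples, noise_var).item()
bers_no_training /= epochs

print(bers_no_training)
```

The structure of the loop stays the same as in the TF version. The only real change is that the values passed through feed_dict become function arguments and the fetched tensor becomes the return value.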

Also, please check PyTorch introductory tutorials to get a feel of the framework before diving into it.


 