Is there any replacement for tf.random_gamma in pytorch?

I'm converting a TensorFlow repository to PyTorch code. I came across this line of code:

tf.squeeze(tf.random_gamma(shape=(self.n_sample,), alpha=self.alpha + tf.to_float(self.B)))

I would like to know the equivalent of tf.random_gamma in PyTorch. I think torch.distributions.gamma.Gamma doesn't work the same way.

It looks like torch.distributions.gamma.Gamma can be used in this case. Here is an example:

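import torch
from torch.distributions.gamma import Gamma


def random_gamma(shape, alpha, beta=1.0):
  # Expand alpha and beta into full tensors of the requested shape, so that
  # sample() returns one value per element. torch.as_tensor also accepts an
  # existing tensor without the copy warning that torch.tensor would emit.
  alpha = torch.ones(shape) * torch.as_tensor(alpha, dtype=torch.float32)
  beta = torch.ones(shape) * torch.as_tensor(beta, dtype=torch.float32)
  # The second argument of Gamma is the rate (inverse scale), like beta in tf.random_gamma.
  gamma_distribution = Gamma(alpha, beta)

  return gamma_distribution.sample()

print(random_gamma(shape=(10,), alpha=3.0))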

Output (the values are random and will differ on each run):

tensor([2.7673, 1.5498, 6.5191, 5.2923, 3.3204, 3.9286, 1.4163, 1.2400, 3.9661, 1.7663])
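You can check that the second argument of Gamma plays the same role as beta in tf.random_gamma by comparing sample statistics with the known moments. In TF, beta is an inverse scale, which PyTorch calls the rate. The moments are mean α/β and variance α/β². The sketch below is a minimal check. The values alpha=3.0 and beta=2.0, the sample count and the fixed seed are arbitrary choices made for illustration.

```python
import torch
from torch.distributions.gamma import Gamma

torch.manual_seed(0)  # fixed seed so the check is reproducible

alpha, beta = 3.0, 2.0  # arbitrary example parameters
# Second argument is the rate (inverse scale), like `beta` in tf.random_gamma
samples = Gamma(torch.tensor(alpha), torch.tensor(beta)).sample((200_000,))

# For Gamma(concentration=alpha, rate=beta):
#   mean = alpha / beta, variance = alpha / beta**2
print(samples.mean().item(), alpha / beta)
print(samples.var().item(), alpha / beta ** 2)
```

If beta were a scale parameter instead, the mean would come out near α·β = 6 rather than α/β = 1.5.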

The difference is in how the shape is specified. In the function above, alpha and beta are expanded into full tensors of the requested shape, whereas tf.random_gamma takes the shape and the parameter values as separate arguments. Also, in TF beta defaults to 1, which the example code imitates.
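Gamma also broadcasts its parameters, and `sample()` accepts a sample shape. That sample shape is prepended to the parameters' own shape, which is the same convention tf.random_gamma uses for its output (`shape + alpha.shape`). The sketch below translates the line from the question this way, assuming `self.alpha` and `self.B` have become PyTorch tensors. The names `n_sample`, `alpha` and `B` are hypothetical stand-ins with made-up values.

```python
import torch
from torch.distributions.gamma import Gamma

# Hypothetical stand-ins for self.n_sample, self.alpha and self.B
n_sample = 5
alpha = torch.tensor([0.5, 1.0, 2.0])
B = torch.tensor([1, 0, 3])  # integer tensor, like the one cast by tf.to_float

# TF: tf.squeeze(tf.random_gamma(shape=(n_sample,), alpha=alpha + tf.to_float(B)))
# beta defaults to 1 in TF, so a rate of 1.0 is passed here.
concentration = alpha + B.float()
samples = Gamma(concentration, 1.0).sample((n_sample,))
print(samples.shape)  # torch.Size([5, 3]): sample shape, then parameter shape

# tf.squeeze drops every size-1 dimension; Tensor.squeeze() does the same
result = samples.squeeze()
```

With this form, there is no need to build full-size alpha and beta tensors by hand.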

If the function is going to be called many times, though, it makes sense to create the distribution instance once and reuse it.
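Here is a minimal sketch of that idea, assuming the parameters stay fixed between calls. The Gamma object is built once, and `sample()` is called on it with whatever shape is needed. The class name `GammaSampler` is hypothetical, chosen only for illustration.

```python
import torch
from torch.distributions.gamma import Gamma


class GammaSampler:
    """Hypothetical helper that builds one Gamma distribution and reuses it."""

    def __init__(self, alpha, beta=1.0):
        # Built once; beta is the rate (inverse scale), defaulting to 1 as in TF
        self.dist = Gamma(torch.as_tensor(alpha, dtype=torch.float32),
                          torch.as_tensor(beta, dtype=torch.float32))

    def __call__(self, shape):
        # sample() prepends `shape` to the parameters' batch shape
        return self.dist.sample(shape)


sampler = GammaSampler(alpha=3.0)
first = sampler((10,))
second = sampler((2, 4))
```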
