繁体   English   中英

需要帮助理解 GPT-2s 源代码中的 Python 函数

[英]Need Help understanding a Python Function in GPT-2s Source Code

我正在 Github 中浏览 GPT-2 的源代码。 我试图了解这一切是如何运作的。 我被一个函数难住了,我希望有人能向我解释发生了什么。

https://github.com/nshepperd/gpt-2/blob/finetuning/src/model.py

代码可以在上面链接中的 model.py 中找到。 这里具体是:

def shape_list(x):
   """Deal with dynamic shape in tensorflow cleanly."""
   static = x.shape.as_list()
   dynamic = tf.shape(x)
   return [dynamic[i] if s is None else s for i, s in enumerate(static)]

我对 Tensorflow.Shape() 返回的内容以及静态和动态形状之间的差异进行了一些研究: 如何理解 TensorFlow 中的静态形状和动态形状?

我也通读了这一系列文章: https : //medium.com/analytics-vidhya/understanding-the-gpt-2-source-code-part-3-9796a5a5cc7c

尽管读了这么多,我还是不完全确定发生了什么。 我不清楚的是最后一句话:

return [dynamic[i] if s is None else s for i, s in enumerate(static)]

它到底在说什么? 我的猜测是函数的目的是确定 X 的值是否已经定义。 如果没有,则返回静态形状,如果有,则返回动态形状。

我离这儿很远吗?

您的问题不在于 Tensorflow,而在于 Python 中的列表推导式,这是一种基于其他可迭代对象定义列表的更 Python 化的方式。

最后一条语句(几乎*)相当于:

ret = []
for i, s in enumerate(static):
  if s is None:
    ret.append(dynamic[i])
  else:
    ret.append(s)
return ret

* : 关于上面的“几乎”,其实推导的效率更高,因为内部是为整个结果预先分配内存,而循环每次迭代都会appends s,从而导致扩展列表时多次分配,比较慢.

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM