[英]Python AST - finding particular named function calls
我正在嘗試分析一些 Python 代碼,以確定調用特定函數的位置以及正在傳遞的 arguments。
例如,假設我有一個包含model.fit(X_train,y_train)
的 ML 腳本。 我想在腳本中找到這一行,確定適合的 object(即model
),並將X_train
和y_train
識別為 arguments(以及任何其他)。
我是 AST 的新手,所以我不知道如何有效地做到這一點。
到目前為止,我已經能夠通過遍歷子節點列表(使用ast.iter_child_nodes
)找到有問題的行,直到我到達ast.Call
object,然后調用它的func.attr
,返回"fit"
。 我還可以使用args
獲得"X_train"
和"y_train"
。
問題是我必須提前知道它在哪里才能這樣做,所以它不是特別有用。 這個想法是讓它自動獲取我正在尋找的信息。
此外,我無法找到一種方法來確定model
是什么叫fit
。
您可以遍歷ast
並搜索名稱fit
的ast.Call
節點:
import ast
def fit_calls(tree):
for i in ast.walk(tree):
if isinstance(i, ast.Call) and isinstance(i.func, ast.Attribute) and i.func.attr == 'fit':
yield {'model_obj_str':ast.unparse(i.func.value),
'model_obj_ast':i.func.value,
'args':[ast.unparse(j) for j in i.args],
'kwargs':{j.arg:ast.unparse(j.value) for j in i.keywords}}
測試樣品:
#https://www.tensorflow.org/api_docs/python/tf/keras/Model
sample_1 = """
model = tf.keras.models.Model(
inputs=inputs, outputs=[output_1, output_2])
model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
model.fit(x, (y, y))
model.metrics_names
"""
sample_2 = """
optimizer = tf.keras.optimizers.SGD()
model.compile(optimizer, loss='mse', steps_per_execution=10)
model.fit(dataset, epochs=2, steps_per_epoch=10)
"""
sample_3 = """
x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2))
_ = model.fit(x, y, verbose=0)
"""
#https://scikit-learn.org/stable/developers/develop.html
sample_4 = """
estimator = estimator.fit(data, targets)
"""
sample_5 = """
y_predicted = SVC(C=100).fit(X_train, y_train).predict(X_test)
"""
print([*fit_calls(ast.parse(sample_1))])
print([*fit_calls(ast.parse(sample_2))])
print([*fit_calls(ast.parse(sample_3))])
print([*fit_calls(ast.parse(sample_4))])
print([*fit_calls(ast.parse(sample_5))])
Output:
[{'model_obj_str': 'model', 'model_obj_ast': <ast.Name object at 0x1007737c0>,
'args': ['x', '(y, y)'], 'kwargs': {}}]
[{'model_obj_str': 'model', 'model_obj_ast': <ast.Name object at 0x1007731f0>,
'args': ['dataset'], 'kwargs': {'epochs': '2', 'steps_per_epoch': '10'}}]
[{'model_obj_str': 'model', 'model_obj_ast': <ast.Name object at 0x100773d00>,
'args': ['x', 'y'], 'kwargs': {'verbose': '0'}}]
[{'model_obj_str': 'estimator', 'model_obj_ast': <ast.Name object at 0x100773ca0>,
'args': ['data', 'targets'], 'kwargs': {}}]
[{'model_obj_str': 'SVC(C=100)', 'model_obj_ast': <ast.Call object at 0x100773130>,
'args': ['X_train', 'y_train'], 'kwargs': {}}]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.