繁体   English   中英

Python 无限运行的多处理代码

[英]Python multiprocessing code running infinitely

我正在尝试使用 sklearn 和 python 的内置多处理库同时训练 2 个模型。

def train_model(model, X, y):
    model.fit(X, y)
    return model

from multiprocessing import Process

p1 = Process(target = train_model, args = (dt, X_train, y_train))
p2 = Process(target = train_model, args = (lr, X_train, y_train))

p1.start()
p2.start()

p1.join()
p2.join()

但是,在运行这段代码后,它会继续无限运行。 单独训练这两个模型不会超过几秒钟。

如果我的方法是错误的,我该如何并行训练 2 个模型?

编辑:Python 版本为 3.8.0。 我在 Windows 10 上的 Jupyter Notebook 上运行这段代码。

编辑 2:问题似乎出在 Jupyter Notebook 上。 相同的代码在 Google Colab 上运行没有任何问题。

编辑 3:我现在正尝试使用我的终端运行这段代码

dt = DecisionTreeClassifier(class_weight='balanced')
lr = LogisticRegression(class_weight='balanced')


def train_model(model, X, y):
    model.fit(X, y)
    return model


p1 = Process(target=train_model, args=(dt, X_train, y_train))
p2 = Process(target=train_model, args=((lr, X_train, y_train)))

if __name__ == '__main__':
    p1.start()
    p2.start()
    p1.join()
    p2.join()

    dt_pred = dt.predict(X_test)
    lr_pred = lr.predict(X_test)

    print("Classification report for Decision Tree:",classification_report(y_test,dt_pred))
    print("Classification report for Logistic Regression", classification_report(y_test, lr_pred))

并得到以下错误

Traceback (most recent call last):
  File "D:/Bennett/HPC/E19CSE058_Lab3/E19CSE058_Lab3_Pt2.py", line 33, in <module>
    dt_pred = dt.predict(X_test)
  File "E:\Anaconda3\lib\site-packages\sklearn\tree\_classes.py", line 436, in predict
    check_is_fitted(self)
  File "E:\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "E:\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 1041, in check_is_fitted
    raise NotFittedError(msg % {'name': type(estimator).__name__})
sklearn.exceptions.NotFittedError: This DecisionTreeClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

似乎通过多处理完成的培训并没有反映在流程之外。 我该如何应对?

亚伦有正确的答案。 在 Windows 上,每个进程都从头开始运行您的脚本,这将启动另外两个进程,每个进程又启动两个进程,等等。任何必须仅在主进程中运行的东西都需要受到"__main__"的保护测试:

from multiprocessing import Process

def train_model(model, X, y):
    model.fit(X, y)
    return model

def main():
    p1 = Process(target = train_model, args = (dt, X_train, y_train))
    p2 = Process(target = train_model, args = (lr, X_train, y_train))

    p1.start()
    p2.start()

    p1.join()
    p2.join()

if __name__ == "__main__":
    main()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM