簡體   English   中英

如何在 GPU 中運行機器學習算法

[英]How to run Machine Learning algorithms in GPU

我使用反向消除算法來減少特征。 但關鍵是有大量的特征和樣本,它在 CPU 上運行並且相當慢。

我怎么能像訓練深度學習一樣在 GPU 上運行它的多線程。 這是我的代碼

import pandas as pd
data = pd.read_csv('data.csv')
X = data.drop(['Path','id','label'], axis=1)
y = data['label']

from mlxtend.feature_selection import SequentialFeatureSelector as sfs
from sklearn.linear_model import LinearRegression
lreg = LinearRegression()
new_sfs = sfs(lreg, k_features=1600, forward=False, verbose=1, scoring='neg_mean_squared_error')
new_sfs = new_sfs.fit(X, y)
feat_names = list(sfs1.k_feature_names_)

你考慮過使用JAX嗎? 您可以輕松地使用 numpy 例程(用 jax 編寫)將計算卸載到 GPU(或任何其他設備,如 TPU)。 這是 JAX 中基本線性回歸的示例:

import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

@jax.jit
def gpu_linear(x,y):
    return jnp.linalg.lstsq(x, y, rcond=None)[0]

# Setting key for random number generation
key_ = jax.random.PRNGKey(seed=1)

# Generating toy example
key_, subkey_ = jax.random.split(key=key_)
x = jax.random.normal(key=subkey_, shape=(100000,))
x = jnp.vstack([x,jnp.ones(len(x))]).T

key_, subkey_ = jax.random.split(key=key_)
y = 1.6*x[:,0] + jax.random.normal(key=subkey_, shape=(100000,))


m, c = gpu_linear(x,y)

plt.plot(x[:,0], y, 'o', label='original data')
plt.plot(x[:,0], m*x[:,0] + c, 'r', label='fitted line')
plt.legend()
plt.show()

在此處輸入圖像描述

我使用%timeit對上述實現進行計時:

%timeit gpu_linear(x,y)
243 µs ± 691 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

這是在 numpy 中實現的相同的東西(僅供參考)。

import numpy as np
import matplotlib.pyplot as plt

def cpu_linear(x,y):
    return np.linalg.lstsq(x, y, rcond=None)[0]

# Generating toy example
x = np.random.normal(0, 1, 100000)
x = np.vstack([x,jnp.ones(len(x))]).T

y = 1.6*x[:,0] + np.random.normal(0, 1, 100000)


m, c = cpu_linear(x,y)

plt.plot(x[:,0], y, 'o', label='original data')
plt.plot(x[:,0], m*x[:,0] + c, 'r', label='fitted line')
plt.legend()
plt.show()

%timeit cpu_linear(x,y)
1.52 ms ± 22.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

正如您所看到的,當我使用 JAX 時速度會有所提高,因為我的所有計算都已卸載到 GPU。 即使你不使用 GPU,JAX 也可以緩存你的函數(當我使用@jax.jit裝飾函數時),你也會看到 CPU 的加速(與幼稚的 numpy 或 scikit learn 相比)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM