
Python Machine Learning Beginner Question

I just started studying machine learning and came across this code. I don't understand it, and I don't know how to search for it either... I'm stuck here, please help. Here is the example code:

from sklearn import datasets, model_selection
import matplotlib.pyplot as plt
import numpy

X, y = datasets.load_diabetes(return_X_y=True)
X = X[:, numpy.newaxis, 2] # I didn't understand this part
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.33)
plt.scatter(X_test, y_test,  color='black')
plt.show()

Where does the 2 come from? What is np.newaxis? (I think it is a method that returns None, but I'm not sure.) Also, what are these comma-separated parameters inside the square brackets? Please tell me what this is called or explain what it is. Thank you :)

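This is called indexing (or slicing, when a range such as : is used). Inside the square brackets, each comma-separated entry indexes one axis of the array, so X[:, 2] means "all rows (:), column 2". You can read more about it in NumPy's user guide.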
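A minimal sketch of how the comma-separated entries work, using a small toy array (not the diabetes data) so the results are easy to check by eye. It uses the common alias np for numpy; the question's code imports it as plain numpy, which works the same way.

```python
import numpy as np

# A small 3x4 toy array with the values 0..11
a = np.arange(12).reshape(3, 4)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]]

# Each comma-separated entry indexes one axis:
# the first picks rows (axis 0), the second picks columns (axis 1).
element = a[1, 2]     # row 1, column 2 -> 6
row = a[1, :]         # ":" is a slice meaning "everything on this axis" -> [4 5 6 7]
column = a[:, 2]      # every row, column 2 -> [ 2  6 10]
block = a[0:2, 1:3]   # rows 0-1, columns 1-2 -> [[1 2] [5 6]]

print(element, row, column, block, sep="\n")
```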

The 2 is an index chosen by whoever wrote the code. Combined with the : before it, it selects the element at index 2 (the 3rd element, since counting starts at 0) of every row, i.e. the 3rd column, which is the bmi feature according to scikit-learn's diabetes dataset documentation.
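To see why index 2 means bmi, here is a sketch with the dataset's feature order written out by hand (as listed in scikit-learn's documentation) and a hypothetical 3-row stand-in for the real feature matrix; X_toy is made up purely for illustration.

```python
import numpy as np

# Column order of the diabetes features, as listed in scikit-learn's docs
feature_names = ["age", "sex", "bmi", "bp", "s1", "s2", "s3", "s4", "s5", "s6"]

# Hypothetical stand-in for the real feature matrix: 3 rows x 10 columns,
# where row r holds the values r*10 + 0 ... r*10 + 9
X_toy = np.arange(30).reshape(3, 10)

print(feature_names[2])   # bmi: index 2 is the 3rd entry
bmi_column = X_toy[:, 2]  # the value at index 2 of every row
print(bmi_column)         # [ 2 12 22]
print(bmi_column.shape)   # (3,): a plain 1D array
```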

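np.newaxis is a NumPy constant (it is simply an alias for None, not a method) that, when placed inside an index, inserts a new axis of length 1. Here it turns what would be a 1D array of 442 values into a 2D array of shape (442, 1), i.e. a single column. You can read more about it in the NumPy documentation.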
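A minimal sketch comparing the shapes with and without np.newaxis, again on a hypothetical toy array. A single-column 2D shape is what scikit-learn estimators expect for X (n_samples, n_features) when the data is later used to fit a model. The last two lines show other common ways to get the same result; they are alternatives, not what the original code uses.

```python
import numpy as np

print(np.newaxis is None)  # True: newaxis is just another name for None

X_toy = np.arange(30).reshape(3, 10)  # hypothetical stand-in: 3 rows x 10 features

flat = X_toy[:, 2]             # shape (3,): the integer index drops the column axis
col = X_toy[:, np.newaxis, 2]  # shape (3, 1): newaxis inserts a length-1 axis
print(flat.shape, col.shape)   # (3,) (3, 1)

# Equivalent single-column 2D arrays
same_1 = X_toy[:, 2:3]               # a slice keeps the axis instead of dropping it
same_2 = X_toy[:, 2].reshape(-1, 1)  # -1 lets NumPy infer the number of rows
print(np.array_equal(col, same_1), np.array_equal(col, same_2))  # True True
```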

Therefore, the code keeps only one of the 10 available features of the dataset (bmi) as its input data, and then splits that data into training and test sets, with about 33% of the samples going to the test set (test_size=0.33).
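Putting it together, this sketch repeats the question's steps on the real dataset and prints the shapes at each stage (plotting omitted). The random_state=0 argument is an addition here so the split is reproducible; the original code leaves it out, so its split differs on every run.

```python
import numpy as np
from sklearn import datasets, model_selection

X, y = datasets.load_diabetes(return_X_y=True)
print(X.shape, y.shape)  # (442, 10) (442,): 442 samples, 10 features

X = X[:, np.newaxis, 2]  # keep only the bmi column, as a 2D array
print(X.shape)           # (442, 1)

# test_size=0.33 puts ceil(0.33 * 442) = 146 samples in the test set
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.33, random_state=0
)
print(X_train.shape, X_test.shape)  # (296, 1) (146, 1)
```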

Hope this helps.

The technical posts on this site are licensed under CC BY-SA 4.0. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

© 2020-2024 STACKOOM.COM