簡體   English   中英

在二維列表中獲取正確的值

[英]Get the correct value in a 2d list

我有一個多項式回歸 plot 並且我正在嘗試使用下一個 X 值找到預測值 y。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
import json
import matplotlib.pyplot as plt
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

with open('/Users/aus10/Desktop/PGA/Data_Cleanup/Combined_Player_Stats.json') as json_file:
    players_data = json.load(json_file)

for obj in players_data:
    obj['Scrambling_List'] = [i for i in obj['Scrambling_List'] if i]

for obj in players_data:
    def create_2d_lst(lst):
        try:
            if len(lst) < 1:
                return [0, 0]
            else:
                return [[i, j] for i, j in enumerate(lst)]
        except:
                pass
    try:     
        scrambling = create_2d_lst(obj['Scrambling_List'])
        total_putts_GIR = create_2d_lst(obj['Total_Putts_GIR_List'])
        SG_Putting = create_2d_lst(obj['SG_Putting_List'])
    except Exception:
        pass

    data = scrambling
    X = np.array(data)[:,0].reshape(-1,1)
    y = np.array(data)[:,1].reshape(-1,1)

    poly_reg = PolynomialFeatures(degree=4)

    X_poly = poly_reg.fit_transform(X)

    pol_reg = LinearRegression()
    pol_reg.fit(X_poly, y)

    predicted_y = poly_reg.fit_transform(X)
    m = pol_reg.coef_
    c = pol_reg.intercept_

    prediction_value = (len(X) + 1)

    prediction = pol_reg.predict(poly_reg.fit_transform([[prediction_value]]))

    def viz_polymonial():
        plt.scatter(X, y, color='red')
        plt.plot(X, pol_reg.predict(poly_reg.fit_transform(X)), color='blue')
        plt.plot(prediction, marker='x', color='green')
        plt.title('Projected Scrambling Percentage')
        plt.xlabel('Tournaments')
        plt.ylabel('Scrambling Percentage')
        plt.show()
        return

    viz_polymonial()

    

    print(obj['Name'], prediction)

當我使用prediction = prediction_value = (len(X) + 1)prediction = pol_reg.predict(poly_reg.fit_transform([[prediction_value]]))我應該得到X的下一個值,但它返回 0,當它應該是 len(X) + 1。我需要將正確的X值設置為預測值。 我不確定為什么它為零,因為當我打印預測值時,我得到了正確的值

這是 json 的副本

[
  {
    "Name": "Aaron Baddeley",
    "Tournaments": [
      {
        "Scrambling": 71.43,
        "Total_Putts_GIR": 75,
        "SG_Putting": 0.31,
        "Tournament": "Safeway_Open",
        "Date": "08-26-2019"
      },
      {
        "Scrambling": 55.56,
        "Total_Putts_GIR": 92,
        "SG_Putting": 0.03,
        "Tournament": "Shriners_Hospital_for_Children_Open",
        "Date": "10-08-2019"
      },
      {
        "Scrambling": 40,
        "Total_Putts_GIR": 47,
        "SG_Putting": -0.14,
        "Tournament": "Houston",
        "Date": "10-10-2019"
      },
      {
        "Scrambling": 71.43,
        "Total_Putts_GIR": 93,
        "SG_Putting": -0.37,
        "Tournament": "Waste_Management",
        "Date": "01-30-2020"
      },
      {
        "Scrambling": 75,
        "Total_Putts_GIR": 29,
        "SG_Putting": 0.69,
        "Tournament": "The_Genesis",
        "Date": "02-13-2020"
      },
      {
        "Scrambling": 71.43,
        "Total_Putts_GIR": 38,
        "SG_Putting": -0.82,
        "Tournament": "RBC_Heritage",
        "Date": "06-18-2020"
      },
      {
        "Scrambling": 50,
        "Total_Putts_GIR": 30,
        "SG_Putting": 0.88,
        "Tournament": "Travelers",
        "Date": "06-25-2020"
      },
      {
        "Scrambling": 42.86,
        "Total_Putts_GIR": 53,
        "SG_Putting": -1.18,
        "Tournament": "Rocket_Mortgage",
        "Date": "07-02-2020"
      },
      {
        "Scrambling": 43.75,
        "Total_Putts_GIR": 33,
        "SG_Putting": 1.41,
        "Tournament": "Workday",
        "Date": "07-09-2020"
      }
    ],
    "Scrambling_List": [
      71.43,
      55.56,
      40,
      71.43,
      75,
      71.43,
      50,
      42.86,
      43.75
    ],
    "Total_Putts_GIR_List": [
      75,
      92,
      47,
      93,
      29,
      38,
      30,
      53,
      33
    ],
    "SG_Putting_List": [
      0.31,
      0.03,
      -0.14,
      -0.37,
      0.69,
      -0.82,
      0.88,
      -1.18,
      1.41
    ]
  }]

弄清楚了。 plot 中沒有 X 值,因此它被自動設置為 0。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
import json
import matplotlib.pyplot as plt
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

with open('/Users/aus10/Desktop/PGA/Data_Cleanup/Combined_Player_Stats.json') as json_file:
    players_data = json.load(json_file)

for obj in players_data:
    obj['Scrambling_List'] = [i for i in obj['Scrambling_List'] if i]

for obj in players_data:
    def create_2d_lst(lst):
        try:
            if len(lst) < 1:
                return [0, 0]
            else:
                return [[i, j] for i, j in enumerate(lst)]
        except:
                pass
    try:     
        scrambling = create_2d_lst(obj['Scrambling_List'])
        total_putts_GIR = create_2d_lst(obj['Total_Putts_GIR_List'])
        SG_Putting = create_2d_lst(obj['SG_Putting_List'])
    except Exception:
        pass

    data = scrambling
    X = np.array(data)[:,0].reshape(-1,1)
    y = np.array(data)[:,1].reshape(-1,1)

    poly_reg = PolynomialFeatures(degree=4)

    X_poly = poly_reg.fit_transform(X)

    pol_reg = LinearRegression()
    pol_reg.fit(X_poly, y)

    predicted_y = poly_reg.fit_transform(X)
    m = pol_reg.coef_
    c = pol_reg.intercept_

    prediction = pol_reg.predict(poly_reg.fit_transform([[len(X)+1]]))

    def viz_polymonial():
        plt.scatter(X, y, color='red')
        plt.plot(X, pol_reg.predict(poly_reg.fit_transform(X)), color='blue')
        plt.plot(len(X)+1, pol_reg.predict(poly_reg.fit_transform([[len(X)+1]])), marker='x', color='green')
        plt.title('Projected Scrambling Percentage')
        plt.xlabel('Tournaments')
        plt.ylabel('Scrambling Percentage')
        plt.show()
        return

    viz_polymonial()

    print(obj['Name'], prediction)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM