
Decision tree implementation in Python: error with dataset

I have written the following code to implement the decision tree algorithm in Python:

from csv import reader

def load_csv(filename):
    file = open(filename, "rb")
    lines = reader(file)
    dataset = list(lines)
    return dataset

# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
    # count all samples at split point
    n_instances = float(sum([len(group) for group in groups]))
    # sum weighted Gini index for each group
    gini = 0.0
    for group in groups:
        size = float(len(group))
        # avoid divide by zero
        if size == 0:
            continue
        score = 0.0
        # score the group based on the score for each class
        for class_val in classes:
            p = [row[-1] for row in group].count(class_val) / size
            score += p * p
        # weight the group score by its relative size
        gini += (1.0 - score) * (size / n_instances)
    return gini

# Select the best split point for a dataset
def get_split(dataset):
    class_values = list(set(row[-1] for row in dataset))
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    for index in range(len(dataset[0])-1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {'index':b_index, 'value':b_value, 'groups':b_groups}

# Create a terminal node value
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
    left, right = node['groups']
    del(node['groups'])
    # check for a no split
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    # check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    # process left child
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'], max_depth, min_size, depth+1)
    # process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'], max_depth, min_size, depth+1)

# Build a decision tree
def build_tree(train, max_depth, min_size):
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return root

# Print a decision tree
def print_tree(node, depth=0):
    if isinstance(node, dict):
        print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
        print_tree(node['left'], depth+1)
        print_tree(node['right'], depth+1)
    else:
        print('%s[%s]' % ((depth*' ', node)))

# load and prepare data
filename = 'spine.csv'
dataset = load_csv(filename)

tree = build_tree(dataset, 1, 1)
print_tree(tree)
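For reference, the split criterion computed by gini_index is the size-weighted Gini impurity, G = Σ over groups g of (|g| / N)·(1 − Σ_k p_{g,k}²), where p_{g,k} is the share of class k in group g. The standalone snippet below copies that function (Python 3) and runs it on two made-up splits to show the two ends of the range for a two-class problem; the toy rows and the labels 'a'/'b' are illustrative only.

```python
def gini_index(groups, classes):
    # total number of rows across all groups at this split point
    n_instances = float(sum(len(group) for group in groups))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:  # an empty group contributes nothing
            continue
        # sum of squared class proportions within this group
        score = 0.0
        for class_val in classes:
            p = [row[-1] for row in group].count(class_val) / size
            score += p * p
        # weight the group's impurity by its share of all rows
        gini += (1.0 - score) * (size / n_instances)
    return gini

classes = ['a', 'b']
# A perfect split: each side holds a single class.
pure = ([[1.0, 'a'], [2.0, 'a']], [[3.0, 'b'], [4.0, 'b']])
# The worst split for two classes: each side is 50/50.
mixed = ([[1.0, 'a'], [2.0, 'b']], [[3.0, 'a'], [4.0, 'b']])

print(gini_index(pure, classes))   # 0.0
print(gini_index(mixed, classes))  # 0.5
```

get_split keeps whichever candidate split has the lowest of these scores.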

The dataset contains a total of 13 attributes related to the spine, and the data mining algorithm is used to determine whether a person's spine is normal or abnormal.

The dataset, spine.csv, is available at the link below:

https://drive.google.com/file/d/1wubSDVMD2uhJXYJNNbCBHW5Gerp-YDlx/view?usp=sharing

It shows the following error:

Traceback (most recent call last):
  File "spine.py", line 101, in <module>
    print_tree(tree)
  File "spine.py", line 90, in print_tree
    print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
TypeError: float argument required, not str
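The error can be reproduced on its own: in %-formatting, %.3f requires a number, while %s accepts any object. The value '3.5' below is a made-up stand-in for a field read from spine.csv. The snippet is written for Python 3, whose TypeError message is worded differently from the Python 2 traceback above.

```python
value = '3.5'  # what a field read by csv.reader looks like: a str

try:
    '%.3f' % value  # a str where a number is required
except TypeError as exc:
    print('TypeError:', exc)

print('%s' % value)           # 3.5   (the string is inserted as-is)
print('%.3f' % float(value))  # 3.500 (convert first, then format)
```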

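node['value'] is a string, not a float. csv.reader returns every field as a string, and nothing in the code converts them, so the split values stored in the tree are strings too.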
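Because the fields are strings, the comparison row[index] < value in test_split is also a string comparison, which is lexicographic rather than numeric. A minimal sketch of converting the attribute columns to float after loading, assuming every column except the last (the class label) is numeric and the file has no header row; the linked file has not been checked here. str_columns_to_float is a hypothetical helper name, the sample rows are made up, and the snippet targets Python 3.

```python
import csv
import io

# A tiny in-memory stand-in for spine.csv (values are made up):
# numeric attributes followed by a class label in the last column.
sample = io.StringIO("63.03,22.55,Abnormal\n9.5,10.0,Normal\n")
dataset = list(csv.reader(sample))
print(dataset[0])  # ['63.03', '22.55', 'Abnormal'] (every field is a str)

# String comparison goes character by character: '9' > '6', so
# '9.5' < '63.03' is False even though 9.5 < 63.03 numerically.
print('9.5' < '63.03', 9.5 < 63.03)  # False True

def str_columns_to_float(dataset):
    """Hypothetical helper: convert every column except the last
    (the class label) from str to float, in place."""
    for row in dataset:
        for i in range(len(row) - 1):
            row[i] = float(row[i])

str_columns_to_float(dataset)
print(dataset[0])  # [63.03, 22.55, 'Abnormal']
```

In the question's script, the call would go right after dataset = load_csv(filename). Once the values are floats, the original %.3f in print_tree formats them without error.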

Try using %s instead of %.3f in your format string.
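A sketch of print_tree with the value placeholder changed to %s. It is split into a hypothetical tree_lines helper that returns the lines and a print_tree that prints them, only so the output is easy to inspect; replacing %.3f with %s in the question's single function has the same effect. The hand-built tree below is made up but has the shape build_tree produces.

```python
def tree_lines(node, depth=0):
    """Return the printed form of a tree as a list of lines.
    %s accepts str and float alike, so node['value'] may be either."""
    if isinstance(node, dict):
        lines = ['%s[X%d < %s]' % (depth * ' ', node['index'] + 1, node['value'])]
        lines += tree_lines(node['left'], depth + 1)
        lines += tree_lines(node['right'], depth + 1)
        return lines
    # Leaf: the class label chosen by to_terminal
    return ['%s[%s]' % (depth * ' ', node)]

def print_tree(node, depth=0):
    for line in tree_lines(node, depth):
        print(line)

# Depth-1 tree whose split value is still a str, as in the question.
tree = {'index': 0, 'value': '63.03', 'left': 'Normal', 'right': 'Abnormal'}
print_tree(tree)
```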
