簡體   English   中英

使用python將CSV文件轉換為LIBSVM兼容數據文件

[英]Converting CSV file to LIBSVM compatible data file using python

我正在使用libsvm做一個項目,我正在准備我的數據來使用lib。 如何將CSV文件轉換為LIBSVM兼容數據?

CSV文件: https//github.com/scikit-learn/scikit-learn/blob/master/sklearn/datasets/data/iris.csv

在頻率問題中:

如何將其他數據格式轉換為LIBSVM格式?

這取決於您的數據格式。 一種簡單的方法是在libsvm matlab / octave接口中使用libsvmwrite。 以UCI機器學習庫中的CSV(逗號分隔值)文件為例。 我們下載SPECTF.train。 標簽位於第一列。 以下步驟以libsvm格式生成文件。

matlab> SPECTF = csvread('SPECTF.train'); % read a csv file
matlab> labels = SPECTF(:, 1); % labels from the 1st column
matlab> features = SPECTF(:, 2:end); 
matlab> features_sparse = sparse(features); % features must be in a sparse matrix
matlab> libsvmwrite('SPECTFlibsvm.train', labels, features_sparse);
The tranformed data are stored in SPECTFlibsvm.train.
Alternatively, you can use convert.c to convert CSV format to libsvm format.

但我不想使用matlab,我使用python。

我也使用JAVA找到了這個解決方案

任何人都可以推薦一種解決這個問題的方法嗎?

您可以使用csv2libsvm.pycsv轉換為libsvm data

python csv2libsvm.py iris.csv libsvm.data 4 True

其中4表示target indexTrue表示csv表示target index

最后,您可以將libsvm.data作為

0 1:5.1 2:3.5 3:1.4 4:0.2
0 1:4.9 2:3.0 3:1.4 4:0.2
0 1:4.7 2:3.2 3:1.3 4:0.2
0 1:4.6 2:3.1 3:1.5 4:0.2
...

來自iris.csv

150,4,setosa,versicolor,virginica
5.1,3.5,1.4,0.2,0
4.9,3.0,1.4,0.2,0
4.7,3.2,1.3,0.2,0
4.6,3.1,1.5,0.2,0
...

csv2libsvm.py不能與Python3一起使用,也不支持標簽目標(字符串目標),我稍微修改了它。 現在它應該與Python3以及標簽目標一起使用。 我是Python的新手,所以我的代碼可能不是最佳實踐,但我希望可以幫助某人。

#!/usr/bin/env python

"""
Convert CSV file to libsvm format. Works only with numeric variables.
Put -1 as label index (argv[3]) if there are no labels in your file.
Expecting no headers. If present, headers can be skipped with argv[4] == 1.

"""

import sys
import csv
import operator
from collections import defaultdict

def construct_line(label, line, labels_dict):
    new_line = []
    if label.isnumeric():
        if float(label) == 0.0:
            label = "0"
    else:
        if label in labels_dict:
            new_line.append(labels_dict.get(label))
        else:
            label_id = str(len(labels_dict))
            labels_dict[label] = label_id
            new_line.append(label_id)

    for i, item in enumerate(line):
        if item == '' or float(item) == 0.0:
            continue
        elif item=='NaN':
            item="0.0"
        new_item = "%s:%s" % (i + 1, item)
        new_line.append(new_item)
    new_line = " ".join(new_line)
    new_line += "\n"
    return new_line

# ---

input_file = sys.argv[1]
try:
    output_file = sys.argv[2]
except IndexError:
    output_file = input_file+".out"


try:
    label_index = int( sys.argv[3] )
except IndexError:
    label_index = 0

try:
    skip_headers = sys.argv[4]
except IndexError:
    skip_headers = 0

i = open(input_file, 'rt')
o = open(output_file, 'wb')

reader = csv.reader(i)

if skip_headers:
    headers = reader.__next__()

labels_dict = {}
for line in reader:
    if label_index == -1:
        label = '1'
    else:
        label = line.pop(label_index)

    new_line = construct_line(label, line, labels_dict)
    o.write(new_line.encode('utf-8'))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM