
Storing numpy sparse matrix in HDF5 (PyTables)

I am having trouble storing a numpy csr_matrix with PyTables. I'm getting this error:

TypeError: objects of type ``csr_matrix`` are not supported in this context, sorry; supported objects are: NumPy array, record or scalar; homogeneous list or tuple, integer, float, complex or string

My code:

f = tables.openFile(path,'w')

atom = tables.Atom.from_dtype(self.count_vector.dtype)
ds = f.createCArray(f.root, 'count', atom, self.count_vector.shape)
ds[:] = self.count_vector
f.close()

Any ideas?

Thanks

The answer by DaveP is almost right... but can cause problems for very sparse matrices: if the last column(s) or row(s) are empty, they are dropped. So to be sure that everything works, the "shape" attribute must be stored too.

This is the code I regularly use:
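A minimal sketch of the shape loss described above, using only scipy.sparse and a made-up 2x4 matrix (the variable names are illustrative, not from the answers). When no shape is passed, the constructor infers the dimensions from the index arrays:

```python
import numpy as np
from scipy import sparse

# Made-up example: a 2x4 matrix whose last two columns are empty.
m = sparse.csr_matrix(np.array([[1, 0, 0, 0],
                                [0, 2, 0, 0]]))

# Rebuild from the three arrays only; the shape is inferred from them.
inferred = sparse.csr_matrix((m.data, m.indices, m.indptr))

# Rebuild again, this time passing the original shape explicitly.
exact = sparse.csr_matrix((m.data, m.indices, m.indptr), shape=m.shape)

print(m.shape, inferred.shape, exact.shape)
```

For CSR specifically, the row count is still recoverable from the length of indptr, so in this sketch it is the trailing empty columns that go missing; for CSC the roles are swapped.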

import tables as tb
from numpy import array
from scipy import sparse

def store_sparse_mat(m, name, store='store.h5'):
    msg = "This code only works for csr matrices"
    assert(m.__class__ == sparse.csr.csr_matrix), msg
    with tb.openFile(store,'a') as f:
        for par in ('data', 'indices', 'indptr', 'shape'):
            full_name = '%s_%s' % (name, par)
            try:
                n = getattr(f.root, full_name)
                n._f_remove()
            except AttributeError:
                pass

            arr = array(getattr(m, par))
            atom = tb.Atom.from_dtype(arr.dtype)
            ds = f.createCArray(f.root, full_name, atom, arr.shape)
            ds[:] = arr

def load_sparse_mat(name, store='store.h5'):
    with tb.openFile(store) as f:
        pars = []
        for par in ('data', 'indices', 'indptr', 'shape'):
            pars.append(getattr(f.root, '%s_%s' % (name, par)).read())
    m = sparse.csr_matrix(tuple(pars[:3]), shape=pars[3])
    return m

It is trivial to adapt it to csc matrices.
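One possible CSC adaptation is sketched below. It assumes the PyTables 3.x function names (open_file, create_carray) rather than the 2.x names used above, and the function names store_sparse_csc and load_sparse_csc, the node prefix "demo" and the temporary file are made up for illustration. CSC matrices expose the same data, indices, indptr and shape attributes, so only the type check and the constructor change:

```python
import os
import tempfile

import numpy as np
import tables as tb
from scipy import sparse


def store_sparse_csc(m, name, store):
    """Store a CSC matrix as four arrays named <name>_<part>."""
    if not isinstance(m, sparse.csc_matrix):
        raise TypeError("this sketch only handles csc matrices")
    with tb.open_file(store, "a") as f:
        for par in ("data", "indices", "indptr", "shape"):
            full_name = "%s_%s" % (name, par)
            # Drop a node left over from an earlier call, if any.
            if full_name in f.root:
                f.remove_node(f.root, full_name)
            arr = np.array(getattr(m, par))
            atom = tb.Atom.from_dtype(arr.dtype)
            ds = f.create_carray(f.root, full_name, atom, arr.shape)
            ds[:] = arr


def load_sparse_csc(name, store):
    """Rebuild the CSC matrix stored under the given node prefix."""
    with tb.open_file(store) as f:
        pars = [getattr(f.root, "%s_%s" % (name, par)).read()
                for par in ("data", "indices", "indptr", "shape")]
    shape = tuple(int(n) for n in pars[3])
    return sparse.csc_matrix(tuple(pars[:3]), shape=shape)


# Round trip through a temporary file; the last column is empty on purpose.
path = os.path.join(tempfile.mkdtemp(), "store.h5")
original = sparse.csc_matrix(np.array([[1.0, 0.0, 0.0],
                                       [0.0, 2.0, 0.0]]))
store_sparse_csc(original, "demo", path)
restored = load_sparse_csc("demo", path)
print(restored.shape)
```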

A CSR matrix can be fully reconstructed from its data, indices and indptr attributes. These are just regular numpy arrays, so there should be no problem storing them as 3 separate arrays in pytables, then passing them back to the constructor of csr_matrix. See the scipy docs.

Edit: Pietro's answer has pointed out that the shape member should also be stored.

I have updated Pietro Battiston's excellent answer for Python 3.6 and PyTables 3.x, as some PyTables function names have changed in the upgrade from 2.x.

import numpy as np
from scipy import sparse
import tables

def store_sparse_mat(M, name, filename='store.h5'):
    """
    Store a csr matrix in HDF5

    Parameters
    ----------
    M : scipy.sparse.csr_matrix
        sparse matrix to be stored

    name: str
        node prefix in HDF5 hierarchy

    filename: str
        HDF5 filename
    """
    assert isinstance(M, sparse.csr_matrix), 'M must be a csr matrix'
    with tables.open_file(filename, 'a') as f:
        for attribute in ('data', 'indices', 'indptr', 'shape'):
            full_name = f'{name}_{attribute}'

            # remove existing nodes
            try:
                n = getattr(f.root, full_name)
                n._f_remove()
            except AttributeError:
                pass

            # add nodes
            arr = np.array(getattr(M, attribute))
            atom = tables.Atom.from_dtype(arr.dtype)
            ds = f.create_carray(f.root, full_name, atom, arr.shape)
            ds[:] = arr

def load_sparse_mat(name, filename='store.h5'):
    """
    Load a csr matrix from HDF5

    Parameters
    ----------
    name: str
        node prefix in HDF5 hierarchy

    filename: str
        HDF5 filename

    Returns
    -------
    M : scipy.sparse.csr_matrix
        loaded sparse matrix
    """
    with tables.open_file(filename) as f:

        # get nodes
        attributes = []
        for attribute in ('data', 'indices', 'indptr', 'shape'):
            attributes.append(getattr(f.root, f'{name}_{attribute}').read())

    # construct sparse matrix
    M = sparse.csr_matrix(tuple(attributes[:3]), shape=tuple(attributes[3]))
    return M
