
scipy sparse matrix: remove the rows whose all elements are zero

I have a sparse matrix produced by sklearn's TfidfVectorizer. I believe that some of its rows are all-zero, and I want to remove them. However, as far as I know, the existing built-in functions, e.g. nonzero() and eliminate_zeros(), focus on zero entries rather than on rows.

Is there any easy way to remove all-zero rows from a sparse matrix?

Example: What I have now (actually in sparse format):

[ [0, 0, 0]
  [1, 0, 2]
  [0, 0, 1] ]

What I want to get:

[ [1, 0, 2]
  [0, 0, 1] ]

Slicing + getnnz() does the trick:

M = M[M.getnnz(1)>0]

This works directly on a csr_array. You can also remove all-zero columns without changing formats:

M = M[:,M.getnnz(0)>0]

However, if you want to remove both, you need to chain the two slices:

M = M[M.getnnz(1)>0][:,M.getnnz(0)>0] #GOOD

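Passing both masks in a single index,

M = M[M.getnnz(1)>0, M.getnnz(0)>0] #BAD

does not work. I am not sure why, but most likely the two index arrays are paired element-wise, as in NumPy fancy indexing, instead of selecting the cross product of rows and columns.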
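
As a minimal sketch of the slicing approach above, the snippet below uses the matrix from the question with one extra all-zero column added purely for illustration (that column is my assumption, not part of the original example). It uses csr_matrix, which provides getnnz(axis). Note that getnnz() counts stored entries, so if the matrix might hold explicitly stored zeros, calling eliminate_zeros() first is probably a good idea.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Matrix from the question, plus an extra all-zero column (illustrative assumption).
M = csr_matrix(np.array([[0, 0, 0, 0],
                         [1, 0, 0, 2],
                         [0, 0, 0, 1]]))

row_mask = M.getnnz(1) > 0  # True for rows with at least one stored entry
col_mask = M.getnnz(0) > 0  # True for columns with at least one stored entry

rows_only = M[row_mask]          # drop all-zero rows
both = M[row_mask][:, col_mask]  # drop all-zero rows, then all-zero columns

print(rows_only.toarray())
print(both.toarray())
```

Both masks are computed from the original M here. That is safe because dropping all-zero rows cannot change the number of nonzeros in any column.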

There aren't existing functions for this, but it's not too bad to write your own:

import numpy as np
import scipy.sparse

def remove_zero_rows(M):
  M = scipy.sparse.csr_matrix(M)

First, convert the matrix to CSR (compressed sparse row) format. This is important because CSR matrices store their data as a triple of (data, indices, indptr), where data holds the nonzero values, indices stores the column indices, and indptr holds the offsets that mark where each row starts and ends. The docs explain it better:

the column indices for row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored in data[indptr[i]:indptr[i+1]].

So, to find rows without any nonzero values, we can just look at the differences between successive values of M.indptr. Continuing our function from above:

  num_nonzeros = np.diff(M.indptr)
  return M[num_nonzeros != 0]

The second benefit of CSR format here is that it's relatively cheap to slice rows, which simplifies the creation of the resulting matrix.
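
For reference, here is the function from this answer assembled in one piece and applied to the matrix from the question. It is only a sketch. The variables indptr, counts, and result are names I introduced so that the intermediate values can be inspected.

```python
import numpy as np
import scipy.sparse

def remove_zero_rows(M):
    # Convert to CSR so that indptr describes where each row's entries live.
    M = scipy.sparse.csr_matrix(M)
    # indptr[i+1] - indptr[i] is the number of stored entries in row i.
    num_nonzeros = np.diff(M.indptr)
    # Keep only the rows that store at least one entry.
    return M[num_nonzeros != 0]

M = scipy.sparse.csr_matrix(np.array([[0, 0, 0],
                                      [1, 0, 2],
                                      [0, 0, 1]]))

indptr = M.indptr.tolist()           # row offsets into data/indices
counts = np.diff(M.indptr).tolist()  # stored entries per row
result = remove_zero_rows(M)

print(indptr)
print(counts)
print(result.toarray())
```

For this matrix the first two entries of indptr are equal, which is exactly what marks row 0 as empty.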

Thanks for your reply, @perimosocordiae

I just found another solution myself. I am posting it here in case someone needs it in the future.

import numpy

def remove_zero_rows(X):
    # X is a scipy sparse matrix. We want to remove all-zero rows from it.
    nonzero_row_indices, _ = X.nonzero()
    unique_nonzero_indices = numpy.unique(nonzero_row_indices)
    return X[unique_nonzero_indices]
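
One difference between this approach and the getnnz()/indptr ones may be worth knowing: nonzero() skips explicitly stored zeros, while getnnz() and indptr count every stored entry. The sketch below builds a small hand-made matrix (my own example, not from the question) that stores an explicit 0.0, to show where the two disagree and how eliminate_zeros() brings them back in line.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Two rows: row 0 stores an explicit 0.0, row 1 stores a real value.
data = np.array([0.0, 1.0])
indices = np.array([0, 2])
indptr = np.array([0, 1, 2])
X = csr_matrix((data, indices, indptr), shape=(2, 3))

stored_counts = X.getnnz(1).tolist()                   # counts stored entries, zeros included
rows_with_values = np.unique(X.nonzero()[0]).tolist()  # rows holding a truly nonzero value

X.eliminate_zeros()                  # drop explicitly stored zeros in place
counts_after = X.getnnz(1).tolist()  # now agrees with nonzero()

print(stored_counts, rows_with_values, counts_after)
```

If the matrix comes straight from a vectorizer, it probably has no stored zeros and all three approaches should agree. The distinction matters mostly after arithmetic or in-place edits to the data array.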
