How to convert a Numpy array of lower diagonal elements to an array of full matrices

I have an array consisting of the lower triangular elements (diagonal included) of a single symmetric matrix. I can convert that to the full matrix by following the methods from "How to convert triangle matrix to square in NumPy?". For a single matrix, the example looks like:

import numpy as np

# Create the lower triangular elements of a 6x6 matrix.
ld = np.arange(21)

# Create full 6x6 matrix
x = np.zeros((6,6))

# Stuff lower triangular values into it
x[np.tril_indices(6)] = ld

# Populate upper triangular elements
x = x + x.T

# Fix diagonals (they got doubled)
diag_idx = [0, 2, 5, 9, 14, 20]
np.fill_diagonal(x, ld[diag_idx])

print(x)

and we get the expected full matrix

[[ 0.  1.  3.  6. 10. 15.]
 [ 1.  2.  4.  7. 11. 16.]
 [ 3.  4.  5.  8. 12. 17.]
 [ 6.  7.  8.  9. 13. 18.]
 [10. 11. 12. 13. 14. 19.]
 [15. 16. 17. 18. 19. 20.]]
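
The hardcoded diag_idx values above are the positions of the diagonal entries within the row-major lower triangle, which works out to the triangular numbers minus one. As a minimal sketch (the variable n is introduced here for illustration and is not in the original), they can be computed rather than typed in, and cross-checked against np.tril_indices:

```python
import numpy as np

n = 6  # matrix size; assumed here to match the 6x6 example

# Row i of the lower triangle contributes i + 1 entries, so the diagonal
# entry of row i is the last one seen so far: (i + 1)(i + 2)/2 - 1.
diag_idx = np.cumsum(np.arange(1, n + 1)) - 1
print(diag_idx)

# Cross-check: positions where the row index equals the column index.
rows, cols = np.tril_indices(n)
assert np.array_equal(diag_idx, np.flatnonzero(rows == cols))
```

For n = 6 this reproduces the list used in the example, so ld[diag_idx] can be passed to np.fill_diagonal unchanged.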

Now I want to extend this to an array of N sets of lower triangular elements and get back an array of N full matrices. The former has shape (N, 21) and the latter (N, 6, 6). I expanded the single-matrix example into one containing 2 matrices:

# Two sets of lower diagonal elements
ld = np.arange(2*21).reshape(2, 21)

# Two sets of full 6x6 matrices
x = np.zeros((ld.shape[0], 6, 6))

# Find the lower triangular indices of each row and stuff them with the 
# values from the corresponding row in the lower diagonal array
x[:, np.tril_indices(6)] = ld[:]

# Populate upper triangular elements
x[:] = x[:] + x[:].T

# Fix diagonals (they got doubled)
diag_idx = [0, 2, 5, 9, 14, 20]
np.fill_diagonal(x[:], ld[:][diag_idx])

but I get a shape mismatch on the line x[:, np.tril_indices(6)] = ld[:]

ValueError: shape mismatch: value array of shape (2,21) could not be broadcast to indexing result of shape (2,2,21,6)

I could do a normal Python loop over the N sets of lower diagonal values, but was trying to do it all via Numpy. Any suggestions on where I've gone wrong with my indexing?

The expected values in x are:

[[[ 0.  1.  3.  6. 10. 15.]
  [ 1.  2.  4.  7. 11. 16.]
  [ 3.  4.  5.  8. 12. 17.]
  [ 6.  7.  8.  9. 13. 18.]
  [10. 11. 12. 13. 14. 19.]
  [15. 16. 17. 18. 19. 20.]]

 [[21. 22. 24. 27. 31. 36.]
  [22. 23. 25. 28. 32. 37.]
  [24. 25. 26. 29. 33. 38.]
  [27. 28. 29. 30. 34. 39.]
  [31. 32. 33. 34. 35. 40.]
  [36. 37. 38. 39. 40. 41.]]]

You can do it like this; it works for any N.

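import numpy as np

N = 3
ld = np.arange(N*21).reshape(N, 21)
x = np.zeros((ld.shape[0], 6, 6))

# Pass the row and column indices separately so they pair up element-wise
tril_ind = np.tril_indices(6)
x[:, tril_ind[0], tril_ind[1]] = ld

# Mirror into the upper triangle by swapping only the last two axes
x += np.transpose(x, (0, 2, 1))

# The diagonal was added to itself, so halve it
diag_ind = np.diag_indices(6)
x[:, diag_ind[0], diag_ind[1]] /= 2
print(x)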

This prints

[[[ 0.  1.  3.  6. 10. 15.]
  [ 1.  2.  4.  7. 11. 16.]
  [ 3.  4.  5.  8. 12. 17.]
  [ 6.  7.  8.  9. 13. 18.]
  [10. 11. 12. 13. 14. 19.]
  [15. 16. 17. 18. 19. 20.]]

 [[21. 22. 24. 27. 31. 36.]
  [22. 23. 25. 28. 32. 37.]
  [24. 25. 26. 29. 33. 38.]
  [27. 28. 29. 30. 34. 39.]
  [31. 32. 33. 34. 35. 40.]
  [36. 37. 38. 39. 40. 41.]]

 [[42. 43. 45. 48. 52. 57.]
  [43. 44. 46. 49. 53. 58.]
  [45. 46. 47. 50. 54. 59.]
  [48. 49. 50. 51. 55. 60.]
  [52. 53. 54. 55. 56. 61.]
  [57. 58. 59. 60. 61. 62.]]]
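
As for where the original indexing went wrong: np.tril_indices(6) returns a tuple of two index arrays, and when that tuple is nested inside x[:, ...] it is treated as a single integer array of shape (2, 21) applied to axis 1, which is where the (2, 2, 21, 6) in the error comes from. The sketch below reproduces that shape and then wraps the approach in a helper that infers the matrix size from the row length. The name tril_to_full is hypothetical, and the mirroring by a second assignment (instead of adding the transpose and halving the diagonal) is a swapped-in variant, not the method from the answer above.

```python
import math

import numpy as np

n = 6
x = np.zeros((2, n, n))
rows, cols = np.tril_indices(n)  # two index arrays, each of length 21

# One (2, 21) integer array on axis 1: every row index AND every column
# index selects a whole length-6 row of each matrix.
bad_shape = x[:, np.array(np.tril_indices(n))].shape

# Two separate index arrays pair up as (row, col) coordinates.
good_shape = x[:, rows, cols].shape
print(bad_shape, good_shape)


def tril_to_full(ld):
    """Hypothetical helper: (N, n(n+1)/2) lower triangles -> (N, n, n)."""
    ld = np.asarray(ld)
    k = ld.shape[-1]
    # Solve k = n(n+1)/2 for n and reject lengths that are not triangular.
    size = (math.isqrt(8 * k + 1) - 1) // 2
    if size * (size + 1) // 2 != k:
        raise ValueError(f"{k} is not a valid lower-triangle length")
    r, c = np.tril_indices(size)
    full = np.zeros((ld.shape[0], size, size), dtype=ld.dtype)
    full[:, r, c] = ld  # lower triangle, diagonal included
    full[:, c, r] = ld  # mirror; diagonal gets the same values again
    return full


out = tril_to_full(np.arange(2 * 21).reshape(2, 21))
print(out.shape)
```

Because nothing is added or divided, this variant keeps the input dtype (integers stay integers) and needs no separate diagonal fix.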
