I am trying to create an n-by-m matrix of 0s and 1s with a very simple structure:
[[1 0 0 0 0 0 0 ...],
[1 1 0 0 0 0 0 ...],
[1 1 1 0 0 0 0 ...],
[1 1 1 1 0 0 0 ...],
[0 1 1 1 1 0 0 ...],
[0 0 1 1 1 1 0 ...],
...
[... 0 0 0 1 1 1 1],
[... 0 0 0 0 1 1 1],
[... 0 0 0 0 0 1 1],
[... 0 0 0 0 0 0 1]]
However, I'd rather not write loops, since this is probably achievable with something built in, along the lines of A = tf.constant(???, shape=(n, m))
Note that after the first 3 rows, reading the matrix row by row as one flat sequence, there is simply a repetition of four 1s followed by m-3 0s, until the last 3 rows.
So I'm thinking of something along the lines of a repeat of repeats, but I have no idea what syntax to use.
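To see where the m-3 comes from, here is a small plain-Python sketch that builds the matrix with loops (only to inspect it, not as a solution) and collapses its row-by-row flattening into runs. The helper name target and the choice n = 9, m = 6 are assumptions for illustration.

```python
from itertools import groupby

# Hypothetical helper (illustration only): builds the target matrix with
# explicit loops so that its row-major flattening can be inspected.
def target(n, m, width=4):
    # Row i has ones in columns i-width+1 .. i, clipped to the matrix.
    return [[1 if 0 <= i - j < width else 0 for j in range(m)] for i in range(n)]

n, m = 9, 6
flat = [v for row in target(n, m) for v in row]  # read the matrix row by row
# Collapse the flat sequence into (value, run_length) pairs.
pattern = [(value, len(list(group))) for value, group in groupby(flat)]
print(pattern)
```

In the middle of the printed list, (1, 4) and (0, 3) alternate: four 1s, then m - 3 = 3 zeros, because the m - 1 - i zeros at the end of row i and the i - 2 zeros at the start of row i + 1 add up to m - 3. The first and last three rows break the period, which is why a single flat repeat does not cover the whole matrix.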
You're looking for tf.matrix_band_part() (available as tf.linalg.band_part() in TensorFlow 2.x). As per the manual, its function is to
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
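So in your case you'd create a matrix of ones and keep a band four entries wide: the main diagonal plus 3 subdiagonals (num_lower = 3) and no superdiagonals (num_upper = 0). The leading 1 in the shape below is a batch dimension; a plain (n, m) tensor works too: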
tf.matrix_band_part( tf.ones( shape = ( 1, n, m ) ), 3, 0 )
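For reference, the documented rule is that an entry (i, j) survives when (num_lower < 0 or i - j <= num_lower) and (num_upper < 0 or j - i <= num_upper), with a negative bound meaning the whole triangle on that side is kept. The plain-Python sketch below restates that rule for a matrix of ones; band_part_of_ones is a hypothetical helper that loops, so it only explains the two integer arguments and is not a substitute for the op.

```python
# Restatement of the band condition used by tf.matrix_band_part /
# tf.linalg.band_part, written out for a matrix of ones.
def in_band(i, j, num_lower, num_upper):
    # A negative bound means "keep the whole triangle on that side".
    lower_ok = num_lower < 0 or (i - j) <= num_lower
    upper_ok = num_upper < 0 or (j - i) <= num_upper
    return lower_ok and upper_ok

def band_part_of_ones(n, m, num_lower, num_upper):
    """Hypothetical helper: which entries band_part keeps in an n-by-m matrix of ones."""
    return [[1.0 if in_band(i, j, num_lower, num_upper) else 0.0 for j in range(m)]
            for i in range(n)]

for row in band_part_of_ones(9, 6, 3, 0):
    print(row)
```

With arguments (9, 6, 3, 0) it reproduces the 9-by-6 band from the tested output in this answer.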
Tested code:
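import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.matrix_band_part)

x = tf.ones(shape=(1, 9, 6))      # a batch of one 9x6 matrix of ones
y = tf.matrix_band_part(x, 3, 0)  # keep the main diagonal and 3 subdiagonals, nothing above
with tf.Session() as sess:
    res = sess.run(y)
    print(res)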
Output:
[[[1. 0. 0. 0. 0. 0.]
[1. 1. 0. 0. 0. 0.]
[1. 1. 1. 0. 0. 0.]
[1. 1. 1. 1. 0. 0.]
[0. 1. 1. 1. 1. 0.]
[0. 0. 1. 1. 1. 1.]
[0. 0. 0. 1. 1. 1.]
[0. 0. 0. 0. 1. 1.]
[0. 0. 0. 0. 0. 1.]]]
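If you need the same matrix outside TensorFlow, or just want to cross-check the output above, NumPy's np.triu and np.tril can be combined to cut out the same band. This is a swapped-in NumPy technique rather than part of the answer, and band_ones is a hypothetical helper name.

```python
import numpy as np

def band_ones(n, m, num_lower):
    """n-by-m float matrix of ones kept only on the main diagonal and the
    num_lower diagonals below it (the pattern of band_part(ones, num_lower, 0))."""
    ones = np.ones((n, m))
    # np.triu(..., k=-num_lower) zeroes entries more than num_lower below the diagonal;
    # np.tril(..., k=0) then zeroes everything above the main diagonal.
    return np.tril(np.triu(ones, k=-num_lower), k=0)

A = band_ones(9, 6, 3)
print(A)
```

Both functions accept non-square 2-D arrays, so the same call works for any n and m.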