One possible implementation could be this:
```python
import tensorflow as tf

class LocallyDenseLayer(tf.keras.layers.Layer):
    def __init__(self, k, m, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # alternatively, you can move this setup into the build method
        # and infer the shape from the input;
        # this is left as an exercise to the reader
        self.w = self.add_weight(name="weight", shape=(k, m))

    def call(self, inputs):
        # assuming inputs have shape [batch, k, m]:
        # tensordot over axis k gives [batch, m, m]; its diagonal is [batch, m]
        dotp = tf.linalg.diag_part(tf.tensordot(inputs, self.w, axes=[[1], [0]]))
        return tf.nn.relu(dotp)
```
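This uses `tf.tensordot` to compute the dot product over dimension k and then keeps only the diagonal, which contains what we want: entry [b, j] of the diagonal is the dot product of column j of sample b with column j of the weight matrix.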
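To see why the diagonal is the right part to keep, here is a small check written in NumPy (standing in for TensorFlow so it runs without building a model; the sizes are arbitrary). It shows that the diagonal of the tensordot result equals an elementwise product summed over k. The second form skips the [batch, m, m] intermediate, so inside `call` something like `tf.reduce_sum(inputs * self.w, axis=1)` may be a cheaper equivalent.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, k, m = 4, 5, 7                  # small, arbitrary sizes
x = rng.normal(size=(batch, k, m))     # inputs: [batch, k, m]
w = rng.normal(size=(k, m))            # one weight column per position along m

# tensordot contracts axis 1 of x with axis 0 of w -> [batch, m, m]
full = np.tensordot(x, w, axes=([1], [0]))
# diagonal over the last two axes -> [batch, m]
via_diag = np.diagonal(full, axis1=1, axis2=2)

# direct form: column j of each sample dotted with column j of w
via_sum = (x * w).sum(axis=1)

print(full.shape, via_diag.shape)      # (4, 7, 7) (4, 7)
print(np.allclose(via_diag, via_sum))  # True
```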
A simple example of usage:
```python
import tensorflow as tf

X = tf.random.normal((100, 5, 1024))
y = tf.random.normal((100, 1))
model = tf.keras.Sequential(
    [
        tf.keras.Input((5, 1024)),
        LocallyDenseLayer(5, 1024),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.compile(loss="mse", optimizer="sgd")
model.fit(X, y)
```
Okay, I first seriously misinterpreted your question, sorry for that. But if I understand correctly, you want to use `keras.layers.LocallyConnected1D` with `kernel_size=1` and `data_format='channels_first'`.
This would give you a different kernel for each of the m input slices of shape (batch_size, k, 1).
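Roughly, with `data_format='channels_first'` the k axis is read as channels and the m axis as steps, so `kernel_size=1` gives every step j its own kernel over the k channels. The NumPy sketch below emulates only that arithmetic (it is not Keras's implementation), assuming `filters=1` and leaving out the bias, and checks that it matches the `LocallyDenseLayer` computation from the first answer.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, k, m = 3, 5, 6
x = rng.normal(size=(batch, k, m))   # channels_first: k channels, m steps

# kernel_size=1, filters=1 (assumed): an independent (k, 1) kernel for every step j
kernels = rng.normal(size=(m, k, 1))

out = np.empty((batch, 1, m))        # channels_first output: [batch, filters, steps]
for j in range(m):
    # the (batch, k) slice at step j is mapped by its own kernel
    out[:, :, j] = x[:, :, j] @ kernels[j]

# same result as the locally dense form with w[i, j] = kernels[j, i, 0]
w = kernels[:, :, 0].T               # shape (k, m)
print(np.allclose(out[:, 0, :], (x * w).sum(axis=1)))  # True
```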
I solved the problem in an easier way. Suppose we have a 3D dataset with shape [10000, 5, 1024]. I first reshaped it into a matrix in which, for each of the 10000 samples, the 5 elements in each column of its [5, 1024] matrix are placed next to each other, using the following code:
```python
import numpy as np
mat = np.reshape(mat, (mat.shape[0], mat.shape[1] * mat.shape[2]), 'F')  # 'F' = column-major order
```
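On a toy array (sizes chosen only for illustration), the `'F'` reshape puts element `[n, i, j]` at column `i + k*j`, so the k values of each original column become contiguous. The check below shows it matches swapping the last two axes and flattening in the default C order, which may be easier to read.

```python
import numpy as np

n, k, m = 2, 3, 4
mat = np.arange(n * k * m).reshape(n, k, m)

flat_f = np.reshape(mat, (n, k * m), 'F')
# swap to [n, m, k], then flatten row-major: column j's k values end up adjacent
flat_t = mat.transpose(0, 2, 1).reshape(n, k * m)

print(mat[0].tolist())     # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
print(flat_f[0].tolist())  # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
print(np.array_equal(flat_f, flat_t))  # True
```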
So `mat` has shape [10000, 5120]. Finally, I used the predefined `LocallyConnected1D` layer in `tensorflow.keras`:
```python
model = tf.keras.Sequential()
model.add(tf.keras.layers.LocallyConnected1D(filters=1, kernel_size=5, strides=5, data_format='channels_first', input_shape=(1, 5120), activation='relu'))
```
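Because `kernel_size` and `strides` both equal k=5, each output position covers exactly one original column of 5 values with its own weights, so this should reproduce the locally dense layer from the first answer (up to the bias and activation). Below is a NumPy sketch of that arithmetic, with toy sizes and no bias or activation (both simplifications for brevity, not the Keras internals).

```python
import numpy as np

rng = np.random.default_rng(2)
n, k, m = 3, 5, 8
mat = rng.normal(size=(n, k, m))
# channels_first input: 1 channel, k*m steps (column-major flatten as above)
flat = np.reshape(mat, (n, 1, k * m), 'F')

weights = rng.normal(size=(m, k))          # one length-k kernel per output position

out = np.empty((n, 1, m))                  # [batch, filters, output steps]
for j in range(m):
    # kernel_size=k and strides=k -> non-overlapping windows, one per original column
    window = flat[:, 0, j * k:(j + 1) * k]
    out[:, 0, j] = window @ weights[j]

# matches the sum over k of mat * w with w[i, j] = weights[j, i]
print(np.allclose(out[:, 0, :], (mat * weights.T).sum(axis=1)))  # True
```

Note that `input_shape=(1, 5120)` means the data fed to the model needs shape [10000, 1, 5120], for example `mat[:, np.newaxis, :]`, rather than the 2D [10000, 5120] matrix itself.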