How to run BatchNorm on a ragged tensor in TF 2.x?

I'm trying to run BatchNormalization on a batch of ragged tensors in TF 2.x, but I keep running into errors. (I can convert to and from ragged tensors before and after the BatchNorm forward call, but I'm unable to run to_tensor() in non-eager (graph) mode, which is a prerequisite for training the network efficiently.)

PyTorch has BatchNorm1d, but TF does not seem to have any such API for ragged input; any suggestions/pointers would be helpful.
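
For concreteness, here is roughly the round trip I mean (toy data for illustration; this runs eagerly, but the to_tensor() call is what I cannot get to work in graph mode):

import tensorflow as tf

ragged = tf.ragged.constant([[3., 1., 4., 1.], [5., 9., 2.], [6.]])
bn = tf.keras.layers.BatchNormalization()

dense = ragged.to_tensor()            # pad to a rectangular tensor
normed = bn(dense, training=True)     # BatchNorm on the dense view
# Restore the original raggedness; note that the padding zeros took
# part in the normalization statistics.
back = tf.RaggedTensor.from_tensor(normed, lengths=ragged.row_lengths())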

What about creating a custom Keras layer just to convert your ragged tensor to a dense tensor, and then feeding that into the BatchNormalization layer?

I actually tried that. Here is what I did; there is still an issue that I did not have time to fix, so it may or may not help you.

In the code below, I'm creating a super simple "to_tensor" layer that I can use before normalization.

It sort of works, but since I'm creating a new tensor in the to_tensor() call, TF can't find any trainable variables anymore.

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense
from tensorflow.keras.models import Model

def to_tensor(x):
    # Pad each ragged row out to a dense (batch, 4) tensor.
    return x.to_tensor(shape=(None, 4))

X = tf.ragged.constant(
    [[3., 1., 4., 1.], [], [5., 9., 2.], [], [6.]])

y = tf.random.normal(shape=(5, 1))

inp = Input(shape=(None,), ragged=True)
x = Lambda(to_tensor)(inp)
out = Dense(1)(x)

m = Model(inp, out)

# Note: compile() is given no loss here, which seems to be where the
# "no trainable variables" problem comes from; see the fix sketched below.
m.compile(optimizer='adam', metrics=['accuracy'])
history = m.fit(X, y, epochs=10)
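
For what it's worth, here is a sketch of one way this might be fixed (untested beyond the toy data above): give compile() a loss so gradients can be computed, and slot a BatchNormalization layer in after the conversion. The 'mse' loss is an arbitrary choice for the random toy target.

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense, BatchNormalization
from tensorflow.keras.models import Model

def to_tensor(x):
    # Pad each ragged row out to a dense (batch, 4) tensor.
    return x.to_tensor(shape=(None, 4))

X = tf.ragged.constant(
    [[3., 1., 4., 1.], [], [5., 9., 2.], [], [6.]])
y = tf.random.normal(shape=(5, 1))

inp = Input(shape=(None,), ragged=True)
x = Lambda(to_tensor)(inp)
x = BatchNormalization()(x)   # BatchNorm now sees a regular dense tensor
out = Dense(1)(x)

m = Model(inp, out)
m.compile(optimizer='adam', loss='mse')   # giving compile() a loss is the key change
history = m.fit(X, y, epochs=10)

Keep in mind that the padding zeros contribute to the batch statistics, which may or may not matter for your use case.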
