[英]convert tf.dense Tensor to tf.one_hot Tensor on Graph execution Tensorflow
TF版本:2.11
我嘗試使用 TFRecords tf.data 管道訓練一個簡單的 2 輸入分類器
我無法將僅包含標量的tf.dense 張量轉換為tf.onehot 向量
# get all recorddatasets abspath
training_names= [record_path+'/'+rec for rec in os.listdir(record_path) if rec.startswith('train')]
# load in tf dataset
train_dataset = tf.data.TFRecordDataset(training_names[1])
train_dataset = train_dataset.map(return_xy)
映射 function:
def return_xy(example_proto):
#parse example
sample= parse_function(example_proto)
#decode image 1
encoded_image1 = sample['image/encoded_1']
decoded_image1 = decode_image(encoded_image1)
#decode image 2
encoded_image2 = sample['image/encoded_2']
decoded_image2 = decode_image(encoded_image2)
#decode label
print(f'image/object/class/'+level: {sample['image/object/class/'+level]}')
class_label = tf.sparse.to_dense(sample['image/object/class/'+level])
print(f'type of class label :{type(class_label)}')
print(class_label)
# conversion to onehot with depth 26 :: -> how can i extract only the value or convert directly to tf.onehot??
label_onehot=tf.one_hot(class_label,26)
#resizing image
input_left=tf.image.resize(decoded_image1,[416, 416])
input_right=tf.image.resize(decoded_image2,[416, 416])
return {'input_3res1':input_left, 'input_5res2':input_right} , label_onehot
output:
image/object/class/'+level: SparseTensor(indices=Tensor("ParseSingleExample/ParseExample/ParseExampleV2:14", shape=(None, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExampleV2:31", shape=(None,), dtype=int64), dense_shape=Tensor("ParseSingleExample/ParseExample/ParseExampleV2:48", shape=(1,), dtype=int64))
type of class label :<class 'tensorflow.python.framework.ops.Tensor'>
Tensor("SparseToDense:0", shape=(None,), dtype=int64)
但是我確信 label 在這個張量中,因為當急切地運行它時
raw_dataset = tf.data.TFRecordDataset([rec_file])
parsed_dataset = raw_dataset.map(parse_function) # only parsing
for sample in parsed_dataset:
class_label=tf.sparse.to_dense(sample['image/object/class/label_level3'])[0]
print(f'type of class label :{type(class_label)}')
print(f'labels from labelmap :{class_label}')
我得到 output:
type of class label :<class 'tensorflow.python.framework.ops.EagerTensor'>
labels from labelmap :7
如果我只是為 label 選擇一個隨機數並將其傳遞給 tf_one_hot(randint, 26) 然后 model 開始訓練(顯然是無意義的)。
所以問題是我如何轉換:
張量(“SparseToDense:0”,shape=(None,),dtype=int64)
到一個
張量("one_hot:0", shape=(26,), dtype=float32)
到目前為止我嘗試了什么
在調用 data.map(parse_xy) 時,我嘗試在 tf 張量上調用 numpy() 但沒有用,這只適用於急切的張量。
據我所知,我不能使用急切執行,因為 parse_xy function 中的所有內容都在整個圖表上被執行:我已經嘗試啟用急切執行 -> 失敗
https://www.tensorflow.org/api_docs/python/tf/config/run_functions_eagerly
Note: This flag has no effect on functions passed into tf.data transformations as arguments.
tf.data functions are never executed eagerly and are always executed as a compiled Tensorflow Graph.
我也嘗試過使用 tf_pyfunc 但這只會返回另一個形狀未知的 tf.Tensor
def get_onehot(tensor):
class_label=tensor[0]
return tf.one_hot(class_label,26)
並在 parse_xy 中添加以下行:
label_onehot=tf.py_function(func=get_onehot, inp=[class_label], Tout=tf.int64)
但我總是得到一個未知的形狀,它不能僅僅改變 with.set_shape()
更新:僅使用 TF 函數即可解決問題
tf.gather 允許索引一個 TF.tensor
class_label_gather = tf.sparse.to_dense(sample['image/object/class/'+level])
class_indices = tf.gather(tf.cast(class_label_gather,dtype=tf.int32),0)
label_onehot=tf.one_hot(class_indices,26)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.