簡體   English   中英

在tensorflow中添加新的op-形狀函數

[英]Adding new op in tensorflow - Shape functions

我正在嘗試在Tensorflow中添加一個新操作,其中有兩個輸入,即3D張量和一個常數,輸出4D張量。 通過將3D張量復制常數定義的次數,可以得到4D張量。 shape函數通過以下方式實現:

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)
{
    ::tensorflow::shape_inference::ShapeHandle output;
    ::tensorflow::shape_inference::ShapeHandle out1 = c->Vector(::tensorflow::shape_inference::DimensionOrConstant(5));
    TF_RETURN_IF_ERROR(c->Concatenate(c->input(0),out1,&output));
    c->set_output(0,output);
    return Status::OK();
})
.Doc(R"doc(
     Replicating the 3D input tensor in a 4D tensor.
)doc");

我想將第四維的大小(由代碼中的out1定義)設置為第二個輸入(即常量值)。 怎么做?

也許MakeShapeFromShapeTensor是您要找的東西? 就像是:

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)
{
    ::tensorflow::shape_inference::ShapeHandle n;
    TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &n));
    ::tensorflow::shape_inference::ShapeHandle out;
    TF_RETURN_IF_ERROR(c->Concatenate(n, c->input(0), &out));
    c->set_output(0, out);
    return Status::OK();
})

也就是說,您可能知道這一點,但是請確保: TensorFlow中的逐元素算術運算支持broadcast ,因此至少在這種情況下,您不需要此自定義操作。

對於其他情況,您還可以結合使用tf.tiletf.shapetf.concattf.reshape來達到相同的效果。 例如,以下通過重復向量創建矩陣:

import tensorflow as tf
oneD = tf.constant([1,2])
n = tf.constant([5])
twoD = tf.reshape(tf.tile(oneD, n), tf.concat([n, tf.shape(oneD)], 0))

with tf.Session() as sess:
  print oneD.eval()
  print twoD.eval()

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM