[英]Tensorflow/ffjs: Embedding Layer Weights are NaN after training with model.fit() method
我是 tensorflow (tfjs) 编程的新手,我想为任务和相关操作训练嵌入。 这个想法是有 2 个嵌入层(一个用于任务,一个用于相应的操作)。 我感兴趣的是获得两个嵌入层的训练权重。 但是,在执行 model.fit(...) 方法后,嵌入层的所有权重都是NaN 。 我不知道,我做错了什么。
使用以下代码片段,我创建了适当的 tensorflow 网络:
function createNetwork(samples, embeddingDimSize, taskIndexesLength, actionIndexesLength) {
const actionInput = tf.input({name: 'actionInput', shape: [1]})
const actionEmbedding = tf.layers.embedding({name: 'actionEmbedding', inputDim: actionIndexesLength, outputDim: embeddingDimSize}).apply(actionInput)
const taskInput = tf.input({name: 'taskInput', shape: [1]})
const taskEmbedding = tf.layers.embedding({name: 'taskEmbedding', inputDim: taskIndexesLength, outputDim: embeddingDimSize}).apply(taskInput)
const dotLayer = tf.layers.dot({normalize: true, axes: 2}).apply([taskEmbedding, actionEmbedding])
const reshapeLayer = tf.layers.reshape({targetShape: [1]}).apply(dotLayer)
const output = tf.layers.dense({name: "output", activation: "sigmoid", units: 1}).apply(reshapeLayer)
const model = tf.model({name: "myEmbeddings", inputs: [taskInput, actionInput], outputs: output})
model.compile({optimizer: tf.train.adam(0.0001), loss: 'binaryCrossentropy', metrics: ['accuracy']})
return model}
因此,网络的拓扑顺序连接如下:
我为输入层 1 和输入层 2 生成训练示例,并生成带有相应标签(0 或 1)的 output
然后我尝试用训练数据拟合 model
await model.fit([x1, x2], y, { epochs: epochs, batchSize: 50 }
之后,我通过以下方式请求嵌入层的权重:
model.getLayer('taskEmbedding').getWeights()[0].print()
model.getLayer('actionEmbedding').getWeights()[0].print()
执行所有内容时我没有收到任何错误。 任何想法,我的代码出了什么问题以及我必须改变什么?
编辑:经过一番调查,我现在发现了我的代码中的错误。 词嵌入维度索引必须匹配嵌入层索引长度。 在我的例子中,我从 1 而不是 0 开始单词词汇索引。因此,我必须将 +1 添加到嵌入层定义的inputDim变量。 另一种方法是从 0 开始您的单词词汇索引,然后您不能将 +1 添加到 inputDim 变量。 如果您的单词词汇索引从 1 开始,则需要以下示例代码。单词词汇表意味着在该上下文中,您的单词被索引以使它们能够被网络唯一地处理:
const actionEmbedding = tf.layers.embedding({name: 'actionEmbedding', inputDim: actionIndexesLength + 1, outputDim: embeddingDimSize}).apply(actionInput)
和
const taskEmbedding = tf.layers.embedding({name: 'taskEmbedding', inputDim: taskIndexesLength + 1, outputDim: embeddingDimSize}).apply(taskInput)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.