简体   繁体   English

运行线性回归模型时,Tensorflow.js 返回“NaN”值

[英]Tensorflow.js returns “NaN” Value when running Linear Regression Model

I'm trying to run this linear regression model which would essentially give me an output based on const prediction = model.predict((tf.tensor2d([20], [1,1])));我正在尝试运行这个线性回归模型,它基本上会给我一个基于const prediction = model.predict((tf.tensor2d([20], [1,1]))); I'm however unfortunately getting NaN Value everytime I run the code to receive a prediction.然而不幸的是,每次我运行代码以接收预测时都会得到 NaN 值。

What's the best way to approach a solution?寻求解决方案的最佳方法是什么? Are there other methods?还有其他方法吗?

Thank you!谢谢!

Below is the code:下面是代码:

 async function learnLinear() { const fontSize = document.getElementById("count").innerHTML; const parsed = parseInt(fontSize); const model = tf.sequential(); model.add(tf.layers.dense({ units: 1, inputShape: [1] })); const learningRate = 0.0001; const optimizer = tf.train.sgd(learningRate); model.compile({ loss: "meanSquaredError", optimizer: "sgd", }); const xs = tf.tensor2d( [ 54, 20, 22, 34, 18, 47, 28, 54, 36, 51, 44, 31, 39, 19, 45, 48, 32, 27, 25, 54, 27, 38, 25, 38, 57, 49, 28, 19, 59, 28, 27, 55, 60, 49, 40, 45, 35, 45, 39, 25, 50, 58, 28, 59, 21, 37, 47, 31, 46, 18, ], [50, 1] ); const ys = tf.tensor2d( [ 14, 15, 15, 15, 16, 17, 15, 16, 15, 17, 17, 15, 16, 15, 15, 16, 17, 17, 17, 14, 16, 15, 15, 16, 17, 15, 16, 14, 15, 16, 14, 17, 15, 14, 14, 17, 15, 14, 14, 16, 16, 14, 14, 17, 17, 14, 17, 14, 14, 17, ], [50, 1] ); await model.fit(xs, ys, { epochs: 500 }); const prediction = model.predict(tf.tensor2d([20], [1, 1])); const value = prediction.dataSync()[0]; console.log("Prediction", value); }

You forgot to specify what metric the model is supposed to track.您忘记指定模型应该跟踪的指标。

const batchSize = 32;
const epochs = 500;

model.compile({
  loss: "meanSquaredError",
  optimizer: "sgd",
  metrics: ["mse"],
});

await model.fit(xs, ys, batchSize, epochs);

const prediction = model.predict(tf.tensor2d([20], [1, 1]));

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM