
Are the weights of a pre-trained TensorFlow model that has been ported to TensorFlow.js accessible at runtime through the model object?

How do I access the weights of a model while debugging?

When I inspect model.model.weights['dense_3/bias'][0] in the debugger during execution, the actual weight values aren't present. However, when I console.log the expression, the weights are printed. Is there some sort of deferred execution going on?

I have created the snippet below, based on the toxicity classifier Medium article, which shows how to access the weights object for a specific layer.

const threshold = 0.9;
// Which toxicity labels to return.
const labelsToInclude = ['identity_attack', 'insult', 'threat'];
toxicity.load(threshold, labelsToInclude).then(model => {
  // Now you can use the `model` object to label sentences.
  model.classify(['you suck']).then(predictions => {
    console.log("Specific weights: " + model.model.weights['dense_3/bias'][0]);
    document.getElementById("predictions").innerHTML = JSON.stringify(predictions, null, 2);
  });
});
<!DOCTYPE html>
<html lang="en-us">
<head>
  <meta charset="UTF-8">
  <title>Activity 1: Basic HTML Bio</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/toxicity"></script>
</head>
<body>
  <div id="predictions">
    Will be populated by prebuilt toxicity model
  </div>
</body>
</html>

Each entry of model.model.weights (for example 'dense_3/bias') is an array of tensors. The weights of that layer can be accessed by indexing into or iterating over the array.

const t = model.model.weights['dense_3/bias'][0] // t is a tensor
t.print() // displays the tensor's values in the console
// add() does not change the weight: tensors are immutable, so it returns a new tensor
const shifted = t.add(tf.scalar(0.5))

console.log(model.model.weights['dense_3/bias'][0]) will display an object, not the values of the tensor. The reason is that a tensor is an instance of a TypeScript class (compiled to an ordinary JavaScript constructor function), so console.log prints an object whose keys are the attributes of the tensor, such as its shape and dtype. The values themselves are held by the backend rather than stored on that object, so one needs to invoke the print method to see them.

const threshold = 0.9;
// Which toxicity labels to return.
const labelsToInclude = ['identity_attack', 'insult', 'threat'];
toxicity.load(threshold, labelsToInclude).then(model => {
  // print weights
  model.model.weights['dense_3/bias'][0].print();
  // continue processing
});
<!DOCTYPE html>
<html lang="en-us">
<head>
  <meta charset="UTF-8">
  <title>Activity 1: Basic HTML Bio</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/toxicity"></script>
</head>
<body>
</body>
</html>

If you want to get the tensor values on the CPU and display them through the innerHTML of a DOM element, consider using data() (asynchronous, returns a Promise) or dataSync() (synchronous).

From another post, I was able to work out how to access the weights.

Each weight tensor has a data() method that returns a promise, which resolves to the weight values.

const threshold = 0.9;
// Which toxicity labels to return.
const labelsToInclude = ['identity_attack', 'insult', 'threat'];
toxicity.load(threshold, labelsToInclude).then(model => {
  // Now you can use the `model` object to label sentences.
  model.classify(['you suck']).then(predictions => {
    model.model.weights['dense_3/bias'][0].data().then(function(value) {
      document.getElementById("specific_weights").innerHTML = JSON.stringify(value);
    });
    document.getElementById("predictions").innerHTML = JSON.stringify(predictions, null, 2);
  });
});
<!DOCTYPE html>
<html lang="en-us">
<head>
  <meta charset="UTF-8">
  <title>Activity 1: Basic HTML Bio</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/toxicity"></script>
</head>
<body>
  <div id="predictions">
    Will be populated by prebuilt toxicity model
  </div>
  <div id="specific_weights">
    Will contain weights for specific layer
  </div>
</body>
</html>

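Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish this content, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.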

 
Guangdong ICP registration No. 18138465  © 2020-2024 STACKOOM.COM