Custom-trained object detection model: can't read the predicted TensorFlow.js tensor array elements in a ReactJS app
I followed this article https://blog.tensorflow.org/2021/01/custom-object-detection-in-browser.html, this notebook https://colab.research.google.com/drive/1MdzgmdYJk947sXyls45V7auMPttHelBZ?usp=sharing, and this ReactJS app template https://github.com/hugozanini/TFJS-object-detection, but I can't figure out which elements of the TensorFlow.js model's predicted tensor array to use for drawing boxes over the video window in the React app. I'm running into the following problem.
I can't understand or read the array output of the TensorFlow.js model, so I can't use it to display the results.
In the prediction part of the function below, how can I tell which element of the predictions tensor array holds the boxes, scores, or classes?
renderPredictions = predictions => {
  const ctx = this.canvasRef.current.getContext("2d");
  ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

  // Font options.
  const font = "16px sans-serif";
  ctx.font = font;
  ctx.textBaseline = "top";

  // Getting predictions
  // HERE: how to know which prediction element is for boxes, scores or classes?
  const boxes = predictions[4].arraySync();
  const scores = predictions[5].arraySync();
  const classes = predictions[6].dataSync();
  const detections = this.buildDetectedObjects(scores, threshold, boxes, classes, classesDir);

  detections.forEach(item => {
    const x = item['bbox'][0];
    const y = item['bbox'][1];
    const width = item['bbox'][2];
    const height = item['bbox'][3];

    // Draw the bounding box.
    ctx.strokeStyle = "#00FFFF";
    ctx.lineWidth = 4;
    ctx.strokeRect(x, y, width, height);

    // Draw the label background.
    ctx.fillStyle = "#00FFFF";
    const textWidth = ctx.measureText(item["label"] + " " + (100 * item["score"]).toFixed(2) + "%").width;
    const textHeight = parseInt(font, 10); // base 10
    ctx.fillRect(x, y, textWidth + 4, textHeight + 4);
  });
};
Output of the predictions variable, with all root elements and their children expanded:
predictions.forEach(i=>{console.log(i.arraySync())})
Element 0:
[100]
0: 100
length: 1
[[Prototype]]: Array(0)
Element 1:
[Array(100)]
0: Array(100)
0: 1878
1: 1851
2: 1021
3: 1520
4: 1015
5: 1214
6: 1576
7: 973
.
.
.
99: 164
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)
Element 2:
[Array(100)]
0: Array(100)
0: 0.32437506318092346
1: 0.301730751991272
2: 0.29111218452453613
3: 0.28852424025535583
4: 0.270997017621994
5: 0.2643387019634247
6: 0.26040562987327576
7: 0.26039886474609375
8: 0.25926679372787476
.
.
.
99: 0.07848591357469559
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)
Element 3:
[Array(1917)]
0: Array(1917)
[0 … 99]
[100 … 199]
[200 … 299]
[300 … 399]
[400 … 499]
[500 … 599]
[600 … 699]
[700 … 799]
[800 … 899]
[900 … 999]
[1000 … 1099]
[1100 … 1199]
[1200 … 1299]
[1300 … 1399]
[1400 … 1499]
[1500 … 1599]
[1600 … 1699]
[1700 … 1799]
[1800 … 1899]
[1900 … 1916]
length: 1917
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)
Element 4:
[Array(1917)]
0: Array(1917)
[0 … 99]
[100 … 199]
[200 … 299]
[300 … 399]
[400 … 499]
[500 … 599]
[600 … 699]
[700 … 799]
[800 … 899]
[900 … 999]
[1000 … 1099]
[1100 … 1199]
[1200 … 1299]
[1300 … 1399]
[1400 … 1499]
[1500 … 1599]
[1600 … 1699]
[1700 … 1799]
[1800 … 1899]
[1900 … 1916]
length: 1917
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)
Element 5:
[Array(100)]
0: Array(100)
0: (4) [0.6603810787200928, 0.05048862099647522, 1, 0.9434524774551392]
1: (4) [0.23983880877494812, 0, 0.8556625843048096, 0.4294479489326477]
2: (4) [0.8396192789077759, 0.7738925218582153, 1, 1]
3: (4) [0.5717893242835999, 0, 0.9534353613853455, 0.4774531126022339]
4: (4) [0.8307931423187256, 0.6522707939147949, 1, 0.9869933128356934]
5: (4) [0.05821444094181061, 0, 0.4731786847114563, 0.372313916683197]
6: (4) [0.7183740139007568, 0, 0.9927113056182861, 0.49393266439437866]
7: (4) [0.8234801888465881, 0, 1, 0.22610855102539062]
8: (4) [0.8305549621582031, 0.022644445300102234, 1, 0.363572359085083]
9: (4) [0.830806314945221, 0.12703362107276917, 1, 0.46965447068214417]
10: (4) [0.8307640552520752, 0.5480412244796753, 1, 0.8905272483825684]
11: (4) [0.8307292461395264, 0.2323085516691208, 1, 0.5748147368431091]
12: (4) [0.8307291269302368, 0.44283488392829895, 1, 0.7853409051895142]
.
.
.
99: (4) [0.02437518537044525, 0.7758948802947998, 0.2996327877044678, 0.9475241899490356]
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)
Element 6:
[Array(100)]
0: Array(100)
0: (2) [0.008746352046728134, 0.32437506318092346]
1: (2) [0.008580721914768219, 0.301730751991272]
2: (2) [0.0071436576545238495, 0.29111218452453613]
3: (2) [0.00598490796983242, 0.28852424025535583]
4: (2) [0.0071937208995223045, 0.270997017621994]
5: (2) [0.005746915470808744, 0.2643387019634247]
6: (2) [0.005973200313746929, 0.26040562987327576]
7: (2) [0.007103783078491688, 0.26039886474609375]
8: (2) [0.007148009724915028, 0.25926679372787476]
.
.
.
99: (2) [0.005199342034757137, 0.07848591357469559]
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)
Element 7:
[Array(100)]
0: Array(100)
0: 1
1: 1
2: 1
3: 1
4: 1
5: 1
6: 1
7: 1
8: 1
9: 1
.
.
.
99: 1
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)
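A quick way to label a dump like the one above is to recover each tensor's shape from the nesting of its `arraySync()` result. Here is a minimal sketch; `shapeOf` is a hypothetical helper, and it assumes the nested arrays are rectangular:

```javascript
// Hypothetical helper: walk the first element at each nesting level
// to recover the shape of an arraySync() result.
function shapeOf(a) {
  const shape = [];
  while (Array.isArray(a)) {
    shape.push(a.length);
    a = a[0];
  }
  return shape;
}

// A [1][2][4] nesting reports shape [1, 2, 4]:
console.log(shapeOf([[[0.66, 0.05, 1, 0.94], [0.24, 0, 0.86, 0.43]]]));
// Element 0 above ([100] with length 1) reports shape [1]:
console.log(shapeOf([100]));
```

Applied to each element of `predictions`, this makes the dump readable at a glance: for example, a `[1, 100, 4]` element is the box coordinates.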
The complete JS code is as follows:
import React from "react";
import ReactDOM from "react-dom";
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
import "./styles.css";

tf.setBackend('webgl');

const threshold = 0.75;

async function load_model() {
  // It's possible to load the model locally or from a repo.
  // You can choose whatever IP and PORT you want in "http://127.0.0.1:8080/model.json"; just set it up beforehand in your HTTP server.
  const model = await loadGraphModel("http://127.0.0.1:8080/model.json");
  //const model = await loadGraphModel("https://raw.githubusercontent.com/hugozanini/TFJS-object-detection/master/models/kangaroo-detector/model.json");
  return model;
}

let classesDir = {
  0: {
    name: 'connection',
    id: 0
  }
}

class App extends React.Component {
  videoRef = React.createRef();
  canvasRef = React.createRef();

  componentDidMount() {
    if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
      const webCamPromise = navigator.mediaDevices
        .getUserMedia({
          audio: false,
          video: {
            facingMode: "user"
          }
        })
        .then(stream => {
          window.stream = stream;
          this.videoRef.current.srcObject = stream;
          return new Promise((resolve, reject) => {
            this.videoRef.current.onloadedmetadata = () => {
              resolve();
            };
          });
        });
      const modelPromise = load_model();
      Promise.all([modelPromise, webCamPromise])
        .then(values => {
          this.detectFrame(this.videoRef.current, values[0]);
        })
        .catch(error => {
          console.error(error);
        });
    }
  }

  detectFrame = (video, model) => {
    tf.engine().startScope();
    model.executeAsync(this.process_input(video)).then(predictions => {
      this.renderPredictions(predictions, video);
      requestAnimationFrame(() => {
        this.detectFrame(video, model);
      });
      tf.engine().endScope();
    });
  };

  process_input(video_frame) {
    const tfimg = tf.browser.fromPixels(video_frame).toInt();
    const expandedimg = tfimg.transpose([0, 1, 2]).expandDims();
    return expandedimg;
  }

  buildDetectedObjects(scores, threshold, boxes, classes, classesDir) {
    const detectionObjects = [];
    var video_frame = document.getElementById('frame');
    scores.forEach((score, i) => {
      if (score > threshold) {
        const bbox = [];
        const minY = boxes[0][i][0] * video_frame.offsetHeight;
        const minX = boxes[0][i][1] * video_frame.offsetWidth;
        const maxY = boxes[0][i][2] * video_frame.offsetHeight;
        const maxX = boxes[0][i][3] * video_frame.offsetWidth;
        bbox[0] = minX;
        bbox[1] = minY;
        bbox[2] = maxX - minX;
        bbox[3] = maxY - minY;
        detectionObjects.push({
          class: classes[i],
          label: classesDir[classes[i]].name,
          score: score.toFixed(4),
          bbox: bbox
        });
      }
    });
    return detectionObjects;
  }

  renderPredictions = predictions => {
    const ctx = this.canvasRef.current.getContext("2d");
    ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

    // Font options.
    const font = "16px sans-serif";
    ctx.font = font;
    ctx.textBaseline = "top";

    // Getting predictions
    // HERE: how to know which prediction element is for boxes, scores or classes?
    const boxes = predictions[4].arraySync();
    const scores = predictions[5].arraySync();
    const classes = predictions[6].dataSync();
    const detections = this.buildDetectedObjects(scores, threshold, boxes, classes, classesDir);

    detections.forEach(item => {
      const x = item['bbox'][0];
      const y = item['bbox'][1];
      const width = item['bbox'][2];
      const height = item['bbox'][3];

      // Draw the bounding box.
      ctx.strokeStyle = "#00FFFF";
      ctx.lineWidth = 4;
      ctx.strokeRect(x, y, width, height);

      // Draw the label background.
      ctx.fillStyle = "#00FFFF";
      const textWidth = ctx.measureText(item["label"] + " " + (100 * item["score"]).toFixed(2) + "%").width;
      const textHeight = parseInt(font, 10); // base 10
      ctx.fillRect(x, y, textWidth + 4, textHeight + 4);
    });

    detections.forEach(item => {
      const x = item['bbox'][0];
      const y = item['bbox'][1];

      // Draw the text last to ensure it's on top.
      ctx.fillStyle = "#000000";
      ctx.fillText(item["label"] + " " + (100 * item["score"]).toFixed(2) + "%", x, y);
    });
  };

  render() {
    return (
      <div>
        <h1>Real-Time Object Detection</h1>
        <h3>MobileNetV2</h3>
        <video
          style={{height: '600px', width: "500px"}}
          className="size"
          autoPlay
          playsInline
          muted
          ref={this.videoRef}
          width="600"
          height="500"
          id="frame"
        />
        <canvas
          className="size"
          ref={this.canvasRef}
          width="600"
          height="500"
        />
      </div>
    );
  }
}

const rootElement = document.getElementById("root");
ReactDOM.render(<App />, rootElement);
Thanks.
That depends on how your data was labelled, but typically the output is something like an array containing, in each element:
x_center, y_center, height, width, class, score
or
x0, y0, xf, yf, class, score
(where x0, y0 is the top-left corner and xf, yf the bottom-right corner of the bounding box)
In either case, you have to multiply:
(x0, xf, x_center, etc.) * runtime_image_width
(y0, yf, y_center, etc.) * runtime_image_height
because the predicted positions are relative to the image's proportions, not absolute pixel values.
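For example, scaling one normalized box to canvas pixels might look like the sketch below. The `[ymin, xmin, ymax, xmax]` ordering is the TF Object Detection API convention (it matches what `buildDetectedObjects` in the question assumes), and the box values and frame size are purely illustrative:

```javascript
// One normalized box in [ymin, xmin, ymax, xmax] order, values in [0, 1].
const box = [0.2, 0.1, 0.8, 0.5];
const frameWidth = 600;   // runtime image width in px (illustrative)
const frameHeight = 500;  // runtime image height in px (illustrative)

// Scale each normalized coordinate by the matching frame dimension.
const minY = box[0] * frameHeight;
const minX = box[1] * frameWidth;
const maxY = box[2] * frameHeight;
const maxX = box[3] * frameWidth;

// [x, y, width, height], the form ctx.strokeRect expects.
const rect = [minX, minY, maxX - minX, maxY - minY];
console.log(rect);
```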
Long story short: you will have to test different combinations, or assume the output follows exactly the labelling format. But the example you gave is fairly simple:
const boxes = predictions[4].arraySync();
const scores = predictions[5].arraySync();
const classes = predictions[6].dataSync();
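One way to disambiguate without trial and error is to inspect each output's shape: for an SSD-style export with `max_detections = 100`, `detection_boxes` is `[1, 100, 4]`, `num_detections` is `[1]`, and `detection_scores`/`detection_classes` are both `[1, 100]` (those two still need a value or dtype check to tell apart). A sketch, operating on anything carrying a `shape` array such as a `tf.Tensor`; the stand-in objects below are hypothetical:

```javascript
// Classify SSD-style outputs by shape alone. Works on any object with
// a `shape` array, so plain stand-ins are used here instead of tensors.
function classifyByShape(tensors) {
  const roles = {};
  tensors.forEach((t, i) => {
    const s = t.shape;
    // 100 = max_detections in this export; without that check,
    // raw_detection_boxes ([1, 1917, 4]) would match too.
    if (s.length === 3 && s[1] === 100 && s[2] === 4) roles.boxes = i;
    else if (s.length === 1) roles.numDetections = i;
    // [1, 100] is ambiguous: scores are floats in [0, 1], classes are
    // integer-valued ids, so a value check is needed to split them.
  });
  return roles;
}

const fakeOutputs = [
  { shape: [1] },          // num_detections
  { shape: [1, 100] },     // scores or classes
  { shape: [1, 100, 4] },  // detection_boxes
  { shape: [1, 100] },     // scores or classes
];
console.log(classifyByShape(fakeOutputs));
```

In the app itself you could run the same check once over `predictions` and cache the indices, instead of hard-coding `predictions[4]` etc.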
@guinther-kovalski I figured it out by looking at the shapes and sizes of the tensors and by checking which detection squares got drawn on the display:
const boxes = predictions[2].arraySync();
const scores = predictions[7].arraySync();
const classes = predictions[1].dataSync();
But how can this be worked out just by reading those output tensors? Is there any way to make these tensors human-readable?
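It may also be possible to skip positional guessing by reading the output names from the graph model: on a `tf.GraphModel`, `model.outputNodes` lists the output tensor names (whether their order matches the array returned by `executeAsync` is an assumption worth verifying for your particular export, since converted models often expose generic names like `Identity_4:0`). A sketch with a hypothetical name list:

```javascript
// Find the index of a named output. Names like 'detection_scores:0'
// follow the TF Object Detection API convention; your converted model
// may report different ones.
function outputIndex(outputNames, wanted) {
  const i = outputNames.findIndex(n => n.split(':')[0] === wanted);
  if (i === -1) throw new Error(`output "${wanted}" not found`);
  return i;
}

// Hypothetical list, as model.outputNodes might report it:
const names = [
  'num_detections:0',
  'detection_classes:0',
  'detection_boxes:0',
  'detection_scores:0'
];
console.log(outputIndex(names, 'detection_boxes'));
```

With real names in hand, `predictions[outputIndex(model.outputNodes, 'detection_boxes')]` would replace the magic indices.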