Custom trained Object Detection model, can't read predicted TensorflowJS tensors array elements in ReactJS app

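I used this article https://blog.tensorflow.org/2021/01/custom-object-detection-in-browser.html , this notebook https://colab.research.google.com/drive/1MdzgmdYJk947sXyls45V7auMPttHelBZ?usp=sharing and this ReactJS app template https://github.com/hugozanini/TFJS-object-detection , but I can't figure out which elements of the TensorflowJS model's predicted tensor array to use for drawing squares on the video window in the React app, and I'm running into the following problem.


I can't understand and read the array output of the TensorflowJS model, so I can't use it to display the results.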


renderPredictions = predictions => {
    const ctx = this.canvasRef.current.getContext("2d");
    ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

    // Font options.
    const font = "16px sans-serif";
    ctx.font = font;
    ctx.textBaseline = "top";

    //Getting predictions
    //HERE how to know which prediction element is for boxes, scores or classes?
    const boxes = predictions[4].arraySync();
    const scores = predictions[5].arraySync();
    const classes = predictions[6].dataSync();
    const detections = this.buildDetectedObjects(scores, threshold,
                                    boxes, classes, classesDir);

    detections.forEach(item => {
      const x = item['bbox'][0];
      const y = item['bbox'][1];
      const width = item['bbox'][2];
      const height = item['bbox'][3];

      // Draw the bounding box.
      ctx.strokeStyle = "#00FFFF";
      ctx.lineWidth = 4;
      ctx.strokeRect(x, y, width, height);

      // Draw the label background.
      ctx.fillStyle = "#00FFFF";
      const textWidth = ctx.measureText(item["label"] + " " + (100 * item["score"]).toFixed(2) + "%").width;
      const textHeight = parseInt(font, 10); // base 10
      ctx.fillRect(x, y, textWidth + 4, textHeight + 4);
    });
};

Output of the predictions variable, with all root elements and their children expanded:


Element 0:
0: 100
length: 1
[[Prototype]]: Array(0)

Element 1:
0: Array(100)
0: 1878
1: 1851
2: 1021
3: 1520
4: 1015
5: 1214
6: 1576
7: 973
99: 164
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)

Element 2:
0: Array(100)
0: 0.32437506318092346
1: 0.301730751991272
2: 0.29111218452453613
3: 0.28852424025535583
4: 0.270997017621994
5: 0.2643387019634247
6: 0.26040562987327576
7: 0.26039886474609375
8: 0.25926679372787476
99: 0.07848591357469559
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)

Element 3:
0: Array(1917)
[0 … 99]
[100 … 199]
[200 … 299]
[300 … 399]
[400 … 499]
[500 … 599]
[600 … 699]
[700 … 799]
[800 … 899]
[900 … 999]
[1000 … 1099]
[1100 … 1199]
[1200 … 1299]
[1300 … 1399]
[1400 … 1499]
[1500 … 1599]
[1600 … 1699]
[1700 … 1799]
[1800 … 1899]
[1900 … 1916]
length: 1917
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)

Element 4:
0: Array(1917)
[0 … 99]
[100 … 199]
[200 … 299]
[300 … 399]
[400 … 499]
[500 … 599]
[600 … 699]
[700 … 799]
[800 … 899]
[900 … 999]
[1000 … 1099]
[1100 … 1199]
[1200 … 1299]
[1300 … 1399]
[1400 … 1499]
[1500 … 1599]
[1600 … 1699]
[1700 … 1799]
[1800 … 1899]
[1900 … 1916]
length: 1917
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)

Element 5:
0: Array(100)
0: (4) [0.6603810787200928, 0.05048862099647522, 1, 0.9434524774551392]
1: (4) [0.23983880877494812, 0, 0.8556625843048096, 0.4294479489326477]
2: (4) [0.8396192789077759, 0.7738925218582153, 1, 1]
3: (4) [0.5717893242835999, 0, 0.9534353613853455, 0.4774531126022339]
4: (4) [0.8307931423187256, 0.6522707939147949, 1, 0.9869933128356934]
5: (4) [0.05821444094181061, 0, 0.4731786847114563, 0.372313916683197]
6: (4) [0.7183740139007568, 0, 0.9927113056182861, 0.49393266439437866]
7: (4) [0.8234801888465881, 0, 1, 0.22610855102539062]
8: (4) [0.8305549621582031, 0.022644445300102234, 1, 0.363572359085083]
9: (4) [0.830806314945221, 0.12703362107276917, 1, 0.46965447068214417]
10: (4) [0.8307640552520752, 0.5480412244796753, 1, 0.8905272483825684]
11: (4) [0.8307292461395264, 0.2323085516691208, 1, 0.5748147368431091]
12: (4) [0.8307291269302368, 0.44283488392829895, 1, 0.7853409051895142]
99: (4) [0.02437518537044525, 0.7758948802947998, 0.2996327877044678, 0.9475241899490356]
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)

Element 6:
0: Array(100)
0: (2) [0.008746352046728134, 0.32437506318092346]
1: (2) [0.008580721914768219, 0.301730751991272]
2: (2) [0.0071436576545238495, 0.29111218452453613]
3: (2) [0.00598490796983242, 0.28852424025535583]
4: (2) [0.0071937208995223045, 0.270997017621994]
5: (2) [0.005746915470808744, 0.2643387019634247]
6: (2) [0.005973200313746929, 0.26040562987327576]
7: (2) [0.007103783078491688, 0.26039886474609375]
8: (2) [0.007148009724915028, 0.25926679372787476]
99: (2) [0.005199342034757137, 0.07848591357469559]
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)

Element 7:
0: Array(100)
0: 1
1: 1
2: 1
3: 1
4: 1
5: 1
6: 1
7: 1
8: 1
9: 1
99: 1
length: 100
[[Prototype]]: Array(0)
length: 1
[[Prototype]]: Array(0)
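
A dump like the one above is easier to read once each element is reduced to its shape and value range. The following is a minimal sketch, not part of the original code: `shapeOf` and `summarize` are hypothetical helper names, and the sample arrays are shortened stand-ins for the real outputs (which hold 100 rows each).

```javascript
// Derive the shape of a nested array, e.g. the result of tensor.arraySync().
function shapeOf(value) {
  const shape = [];
  let current = value;
  while (Array.isArray(current)) {
    shape.push(current.length);
    current = current[0];
  }
  return shape;
}

// Reduce one output to its shape, value range, and whether all values are integers.
function summarize(nested) {
  const flat = Array.isArray(nested) ? nested.flat(Infinity) : [nested];
  let min = Infinity;
  let max = -Infinity;
  for (const v of flat) {
    if (v < min) min = v;
    if (v > max) max = v;
  }
  return {
    shape: shapeOf(nested),
    min,
    max,
    allIntegers: flat.every(Number.isInteger),
  };
}

// Shortened stand-ins for two of the outputs (assumed sample data).
const scoresLike = [[0.32437506318092346, 0.301730751991272, 0.29111218452453613]];
const boxesLike = [[[0.66, 0.05, 1, 0.94], [0.24, 0, 0.86, 0.43]]];

console.log(summarize(scoresLike));
console.log(summarize(boxesLike));
```

On a live `tf.Tensor`, the `shape` and `dtype` properties give the same shape information directly, so in the app it may be enough to log those for every element of `predictions` before calling `arraySync()`.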


import React from "react";
import ReactDOM from "react-dom";
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
import "./styles.css";

const threshold = 0.75;

async function load_model() {
    // It's possible to load the model locally or from a repo
    // You can choose whatever IP and PORT you want in the "" just set it before in your https server
    const model = await loadGraphModel("");
    //const model = await loadGraphModel("https://raw.githubusercontent.com/hugozanini/TFJS-object-detection/master/models/kangaroo-detector/model.json");
    return model;
}

let classesDir = {
    0: {
        name: 'connection',
        id: 0
    }
};

class App extends React.Component {
  videoRef = React.createRef();
  canvasRef = React.createRef();

  componentDidMount() {
    if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
      // Request the webcam stream and wait until the video metadata is loaded.
      const webCamPromise = navigator.mediaDevices
        .getUserMedia({
          audio: false,
          video: {
            facingMode: "user"
          }
        })
        .then(stream => {
          window.stream = stream;
          this.videoRef.current.srcObject = stream;
          return new Promise((resolve, reject) => {
            this.videoRef.current.onloadedmetadata = () => {
              resolve();
            };
          });
        });

      const modelPromise = load_model();

      // Start detecting once both the model and the webcam are ready.
      Promise.all([modelPromise, webCamPromise])
        .then(values => {
          this.detectFrame(this.videoRef.current, values[0]);
        })
        .catch(error => {
          console.error(error);
        });
    }
  }

  detectFrame = (video, model) => {
    model.executeAsync(this.process_input(video)).then(predictions => {
      this.renderPredictions(predictions, video);
      requestAnimationFrame(() => {
        this.detectFrame(video, model);
      });
    });
  };

  process_input(video_frame) {
    const tfimg = tf.browser.fromPixels(video_frame).toInt();
    const expandedimg = tfimg.transpose([0,1,2]).expandDims();
    return expandedimg;
  }

  buildDetectedObjects(scores, threshold, boxes, classes, classesDir) {
    const detectionObjects = [];
    const video_frame = document.getElementById('frame');

    scores.forEach((score, i) => {
      if (score > threshold) {
        // Boxes are read as normalized [minY, minX, maxY, maxX] and scaled to the displayed video size.
        const bbox = [];
        const minY = boxes[0][i][0] * video_frame.offsetHeight;
        const minX = boxes[0][i][1] * video_frame.offsetWidth;
        const maxY = boxes[0][i][2] * video_frame.offsetHeight;
        const maxX = boxes[0][i][3] * video_frame.offsetWidth;
        bbox[0] = minX;
        bbox[1] = minY;
        bbox[2] = maxX - minX;
        bbox[3] = maxY - minY;
        detectionObjects.push({
          class: classes[i],
          label: classesDir[classes[i]].name,
          score: score.toFixed(4),
          bbox: bbox
        });
      }
    });
    return detectionObjects;
  }

  renderPredictions = predictions => {
    const ctx = this.canvasRef.current.getContext("2d");
    ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

    // Font options.
    const font = "16px sans-serif";
    ctx.font = font;
    ctx.textBaseline = "top";

    //Getting predictions
    // HERE how to know which prediction element is for boxes, scores or classes?
    const boxes = predictions[4].arraySync();
    const scores = predictions[5].arraySync();
    const classes = predictions[6].dataSync();
    const detections = this.buildDetectedObjects(scores, threshold,
                                    boxes, classes, classesDir);

    detections.forEach(item => {
      const x = item['bbox'][0];
      const y = item['bbox'][1];
      const width = item['bbox'][2];
      const height = item['bbox'][3];

      // Draw the bounding box.
      ctx.strokeStyle = "#00FFFF";
      ctx.lineWidth = 4;
      ctx.strokeRect(x, y, width, height);

      // Draw the label background.
      ctx.fillStyle = "#00FFFF";
      const textWidth = ctx.measureText(item["label"] + " " + (100 * item["score"]).toFixed(2) + "%").width;
      const textHeight = parseInt(font, 10); // base 10
      ctx.fillRect(x, y, textWidth + 4, textHeight + 4);
    });

    detections.forEach(item => {
      const x = item['bbox'][0];
      const y = item['bbox'][1];

      // Draw the text last to ensure it's on top.
      ctx.fillStyle = "#000000";
      ctx.fillText(item["label"] + " " + (100*item["score"]).toFixed(2) + "%", x, y);
    });
  };

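  render() {
    return (
      <div>
        <h1>Real-Time Object Detection</h1>
        <video
          style={{height: '600px', width: "500px"}}
          autoPlay playsInline muted ref={this.videoRef} id="frame" />
        <canvas ref={this.canvasRef} />
      </div>
    );
  }
}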

const rootElement = document.getElementById("root");
ReactDOM.render(<App />, rootElement);





(where x0,y0 is the upper-left corner and xf,yf is the bottom-right corner of the bounding box)


(x0,xf,x_center etc..)*runtime_image_width
(y0,yf,y_center etc..)*runtime_image_height
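
The scaling above can be sketched as follows. This is only an illustration under stated assumptions: `toPixelRect` and `centerToPixelRect` are hypothetical helper names, the corner layout `[ymin, xmin, ymax, xmax]` is the order that `buildDetectedObjects` in the question reads, and the center layout `[x_center, y_center, width, height]` is shown only as one alternative a model might use.

```javascript
// Corner layout: normalized [ymin, xmin, ymax, xmax], each value in 0..1.
function toPixelRect(box, imageWidth, imageHeight) {
  const [yMin, xMin, yMax, xMax] = box;
  return {
    x: xMin * imageWidth,
    y: yMin * imageHeight,
    width: (xMax - xMin) * imageWidth,
    height: (yMax - yMin) * imageHeight,
  };
}

// Center layout (assumed alternative): normalized [x_center, y_center, width, height].
function centerToPixelRect(box, imageWidth, imageHeight) {
  const [xCenter, yCenter, boxWidth, boxHeight] = box;
  return {
    x: (xCenter - boxWidth / 2) * imageWidth,
    y: (yCenter - boxHeight / 2) * imageHeight,
    width: boxWidth * imageWidth,
    height: boxHeight * imageHeight,
  };
}

// Example with a 500x600 runtime image.
console.log(toPixelRect([0.25, 0.5, 0.75, 1], 500, 600));
console.log(centerToPixelRect([0.5, 0.5, 0.5, 0.5], 500, 600));
```

The returned `x`, `y`, `width` and `height` match the argument order of `ctx.strokeRect`, so drawing the same box with both helpers is a quick way to see which layout a given model follows.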


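Long story short: you will have to test different combinations, or assume the output follows exactly the format used for labeling. But the example you gave is pretty simple: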

const boxes = predictions[4].arraySync();
const scores = predictions[5].arraySync();
const classes = predictions[6].dataSync(); 

@guinther-kovalski I figured it out by looking at the shapes and sizes of the tensors and by watching the detection squares drawn on the display:

const boxes = predictions[2].arraySync();
const scores = predictions[7].arraySync();
const classes = predictions[1].dataSync();      

But how can this be figured out just by reading these output tensors? Is there any way to make these tensors human-readable?
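
One possible way to read the outputs without trial and error is to guess each tensor's role from its shape and value range. The sketch below is a heuristic, not an official mapping: `guessRole` is a hypothetical helper, the role names assume the model was exported with the TensorFlow Object Detection API, `numClasses = 1` is an assumption, and the summaries are copied from the visible entries of the dump in the question.

```javascript
// Heuristic role guess for one output, based on shape and value statistics.
// Role names assume a TensorFlow Object Detection API export (an assumption).
function guessRole({ shape, min, max, allIntegers }, numDetections, numClasses) {
  const rank = shape.length;
  if (rank === 1) return "num_detections";
  const perDetection = shape[1] === numDetections;
  if (rank === 3 && shape[2] === 4) {
    return perDetection ? "detection_boxes" : "raw_detection_boxes";
  }
  if (rank === 3) {
    return perDetection ? "detection_multiclass_scores" : "raw_detection_scores";
  }
  if (rank === 2 && allIntegers) {
    // Class ids stay small; anchor indices range over all anchors.
    return max <= numClasses ? "detection_classes" : "detection_anchor_indices";
  }
  if (rank === 2 && min >= 0 && max <= 1) return "detection_scores";
  return "unknown";
}

// Summaries taken from the visible entries of the dump (Elements 0, 1, 2, 5, 6, 7).
const summaries = [
  { shape: [1], min: 100, max: 100, allIntegers: true },
  { shape: [1, 100], min: 164, max: 1878, allIntegers: true },
  { shape: [1, 100], min: 0.07848591357469559, max: 0.32437506318092346, allIntegers: false },
  { shape: [1, 100, 4], min: 0, max: 1, allIntegers: false },
  { shape: [1, 100, 2], min: 0.005199342034757137, max: 0.32437506318092346, allIntegers: false },
  { shape: [1, 100], min: 1, max: 1, allIntegers: true },
];

// 100 detections (Element 0) and a single class are assumed here.
const roles = summaries.map(s => guessRole(s, 100, 1));
console.log(roles);
```

The guess only applies to the run the summaries were taken from, and it can fail in edge cases, for example when every score happens to be exactly 0 or 1. The loaded `GraphModel` also exposes `outputNodes`, which lists the output node names; those names may be generic after conversion, so they are a complement to the shape check rather than a replacement.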

