Custom trained Object Detection model, can't read predicted TensorflowJS tensors array elements in ReactJS app

I used this article https://blog.tensorflow.org/2021/01/custom-object-detection-in-browser.html, this notebook https://colab.research.google.com/drive/1MdzgmdYJk947sXyls45V7auMPttHelBZ?usp=sharing and this ReactJS app template https://github.com/hugozanini/TFJS-object-detection, but I can't figure out how to use the TensorflowJS model's predicted tensor array elements to draw squares over the video window in the React app. I ran into the following problem.

I can't understand or read the TensorflowJS model's array output, so I can't use it to display the results.

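In the predictions part of the function below, how do I know which predicted tensor array element holds the boxes, the scores, or the classes?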

renderPredictions = predictions => {
    const ctx = this.canvasRef.current.getContext("2d");
    ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

    // Font options.
    const font = "16px sans-serif";
    ctx.font = font;
    ctx.textBaseline = "top";

    //Getting predictions
    //HERE how to know which prediction element is for boxes, scores or classes?
    const boxes = predictions[4].arraySync();
    const scores = predictions[5].arraySync();
    const classes = predictions[6].dataSync();
    const detections = this.buildDetectedObjects(scores, threshold,
                                    boxes, classes, classesDir);

    detections.forEach(item => {
      const x = item['bbox'][0];
      const y = item['bbox'][1];
      const width = item['bbox'][2];
      const height = item['bbox'][3];

      // Draw the bounding box.
      ctx.strokeStyle = "#00FFFF";
      ctx.lineWidth = 4;
      ctx.strokeRect(x, y, width, height);

      // Draw the label background.
      ctx.fillStyle = "#00FFFF";
      const textWidth = ctx.measureText(item["label"] + " " + (100 * item["score"]).toFixed(2) + "%").width;
      const textHeight = parseInt(font, 10); // base 10
      ctx.fillRect(x, y, textWidth + 4, textHeight + 4);
    });

    // (The label text is drawn in a second loop; see the full code below.)
  };

Output of the predictions variable, with every root element and its children expanded:

predictions.forEach(i=>{console.log(i.arraySync())})

Element 0 (shape [1]):
[100]

Element 1 (shape [1, 100]):
[[1878, 1851, 1021, 1520, 1015, 1214, 1576, 973, ..., 164]]

Element 2 (shape [1, 100]):
[[0.32437506318092346, 0.301730751991272, 0.29111218452453613, 0.28852424025535583, 0.270997017621994, 0.2643387019634247, 0.26040562987327576, 0.26039886474609375, 0.25926679372787476, ..., 0.07848591357469559]]

Element 3 (outer shape [1, 1917]; the 1917 entries were not expanded in the console):
[[ ... 1917 entries ... ]]

Element 4 (outer shape [1, 1917]; the 1917 entries were not expanded in the console):
[[ ... 1917 entries ... ]]

Element 5 (shape [1, 100, 4]):
0: [0.6603810787200928, 0.05048862099647522, 1, 0.9434524774551392]
1: [0.23983880877494812, 0, 0.8556625843048096, 0.4294479489326477]
2: [0.8396192789077759, 0.7738925218582153, 1, 1]
3: [0.5717893242835999, 0, 0.9534353613853455, 0.4774531126022339]
4: [0.8307931423187256, 0.6522707939147949, 1, 0.9869933128356934]
5: [0.05821444094181061, 0, 0.4731786847114563, 0.372313916683197]
6: [0.7183740139007568, 0, 0.9927113056182861, 0.49393266439437866]
7: [0.8234801888465881, 0, 1, 0.22610855102539062]
8: [0.8305549621582031, 0.022644445300102234, 1, 0.363572359085083]
9: [0.830806314945221, 0.12703362107276917, 1, 0.46965447068214417]
10: [0.8307640552520752, 0.5480412244796753, 1, 0.8905272483825684]
11: [0.8307292461395264, 0.2323085516691208, 1, 0.5748147368431091]
12: [0.8307291269302368, 0.44283488392829895, 1, 0.7853409051895142]
...
99: [0.02437518537044525, 0.7758948802947998, 0.2996327877044678, 0.9475241899490356]

Element 6 (shape [1, 100, 2]):
0: [0.008746352046728134, 0.32437506318092346]
1: [0.008580721914768219, 0.301730751991272]
2: [0.0071436576545238495, 0.29111218452453613]
3: [0.00598490796983242, 0.28852424025535583]
4: [0.0071937208995223045, 0.270997017621994]
5: [0.005746915470808744, 0.2643387019634247]
6: [0.005973200313746929, 0.26040562987327576]
7: [0.007103783078491688, 0.26039886474609375]
8: [0.007148009724915028, 0.25926679372787476]
...
99: [0.005199342034757137, 0.07848591357469559]

Element 7 (shape [1, 100]):
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1]]
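Expanding every element in the console buries the one fact that identifies each output: its shape. Every `tf.Tensor` already exposes it, so `predictions.forEach((t, i) => console.log(i, t.shape, t.dtype))` prints one line per output. The sketch below does the same for the plain nested arrays that `arraySync()` returns, in case only the logged arrays are at hand; `shapeOf` and `summarizeOutputs` are hypothetical helper names, not part of TensorFlow.js.

```javascript
// Shape of a rectangular nested array, e.g. the result of tensor.arraySync().
// Only the first element at each depth is inspected; arrays produced from a
// tensor are never ragged, so that is enough.
function shapeOf(nested) {
  const shape = [];
  let level = nested;
  while (Array.isArray(level)) {
    shape.push(level.length);
    level = level[0];
  }
  return shape;
}

// One line per output: its index, its shape and its first few values
// (rounded to 4 decimals), instead of a fully expanded console dump.
function summarizeOutputs(arrays, preview = 3) {
  return arrays.map((arr, i) => {
    const head = [arr].flat(Infinity).slice(0, preview).map((v) => Number(v.toFixed(4)));
    return `#${i} shape=[${shapeOf(arr).join(", ")}] first=${JSON.stringify(head)}`;
  });
}
```

In the app this would be called as `summarizeOutputs(predictions.map((t) => t.arraySync())).forEach((line) => console.log(line))`. Read this way, the dump above has shapes [1], [1, 100], [1, 100], [1, 1917, ...] twice, [1, 100, 4], [1, 100, 2] and [1, 100]; among the outputs with 100 rows, only element 5 carries four numbers per row, which is what a bounding box needs.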

The full JS code is below:

import React from "react";
import ReactDOM from "react-dom";
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
import "./styles.css";
tf.setBackend('webgl');

const threshold = 0.75;

async function load_model() {
    // It's possible to load the model locally or from a repo
    // Any IP and port can be used in "http://127.0.0.1:8080/model.json", as long as a local HTTP server is already serving the model files there
    const model = await loadGraphModel("http://127.0.0.1:8080/model.json");
    //const model = await loadGraphModel("https://raw.githubusercontent.com/hugozanini/TFJS-object-detection/master/models/kangaroo-detector/model.json");
    return model;
  }

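// Label map: class id -> name. TF Object Detection API label maps usually
// start at id 1, so if the model reports 1 for 'connection' this entry may
// need the key 1 instead of 0.
let classesDir = {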
    0: {
        name: 'connection',
        id: 0
    }
}

class App extends React.Component {
  videoRef = React.createRef();
  canvasRef = React.createRef();


  componentDidMount() {
    if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
      const webCamPromise = navigator.mediaDevices
        .getUserMedia({
          audio: false,
          video: {
            facingMode: "user"
          }
        })
        .then(stream => {
          window.stream = stream;
          this.videoRef.current.srcObject = stream;
          return new Promise((resolve, reject) => {
            this.videoRef.current.onloadedmetadata = () => {
              resolve();
            };
          });
        });

      const modelPromise = load_model();

      Promise.all([modelPromise, webCamPromise])
        .then(values => {
          this.detectFrame(this.videoRef.current, values[0]);
        })
        .catch(error => {
          console.error(error);
        });
    }
  }

  detectFrame = (video, model) => {
    tf.engine().startScope();
    model.executeAsync(this.process_input(video)).then(predictions => {
      this.renderPredictions(predictions, video);
      requestAnimationFrame(() => {
        this.detectFrame(video, model);
      });
      tf.engine().endScope();
    });
  };

  process_input(video_frame){
    const tfimg = tf.browser.fromPixels(video_frame).toInt();
    const expandedimg = tfimg.transpose([0,1,2]).expandDims();
    return expandedimg;
  };

  buildDetectedObjects(scores, threshold, boxes, classes, classesDir) {
    const detectionObjects = []
    var video_frame = document.getElementById('frame');

    // arraySync() keeps the batch dimension ([1, N]), so iterate over scores[0].
    scores[0].forEach((score, i) => {
      if (score > threshold) {
        const bbox = [];
        const minY = boxes[0][i][0] * video_frame.offsetHeight;
        const minX = boxes[0][i][1] * video_frame.offsetWidth;
        const maxY = boxes[0][i][2] * video_frame.offsetHeight;
        const maxX = boxes[0][i][3] * video_frame.offsetWidth;
        bbox[0] = minX;
        bbox[1] = minY;
        bbox[2] = maxX - minX;
        bbox[3] = maxY - minY;
        detectionObjects.push({
          class: classes[i],
          label: classesDir[classes[i]].name,
          score: score.toFixed(4),
          bbox: bbox
        })
      }
    })
    return detectionObjects
  }

  renderPredictions = predictions => {
    const ctx = this.canvasRef.current.getContext("2d");
    ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

    // Font options.
    const font = "16px sans-serif";
    ctx.font = font;
    ctx.textBaseline = "top";

    //Getting predictions
    // HERE how to know which prediction element is for boxes, scores or classes?
    const boxes = predictions[4].arraySync();
    const scores = predictions[5].arraySync();
    const classes = predictions[6].dataSync();
    const detections = this.buildDetectedObjects(scores, threshold,
                                    boxes, classes, classesDir);

    detections.forEach(item => {
      const x = item['bbox'][0];
      const y = item['bbox'][1];
      const width = item['bbox'][2];
      const height = item['bbox'][3];

      // Draw the bounding box.
      ctx.strokeStyle = "#00FFFF";
      ctx.lineWidth = 4;
      ctx.strokeRect(x, y, width, height);

      // Draw the label background.
      ctx.fillStyle = "#00FFFF";
      const textWidth = ctx.measureText(item["label"] + " " + (100 * item["score"]).toFixed(2) + "%").width;
      const textHeight = parseInt(font, 10); // base 10
      ctx.fillRect(x, y, textWidth + 4, textHeight + 4);
    });

    detections.forEach(item => {
      const x = item['bbox'][0];
      const y = item['bbox'][1];

      // Draw the text last to ensure it's on top.
      ctx.fillStyle = "#000000";
      ctx.fillText(item["label"] + " " + (100*item["score"]).toFixed(2) + "%", x, y);
    });
  };

  render() {
    return (
      <div>
        <h1>Real-Time Object Detection</h1>
        <h3>MobileNetV2</h3>
        <video
          style={{height: '600px', width: "500px"}}
          className="size"
          autoPlay
          playsInline
          muted
          ref={this.videoRef}
          width="600"
          height="500"
          id="frame"
        />
        <canvas
          className="size"
          ref={this.canvasRef}
          width="600"
          height="500"
        />
      </div>
    );
  }
}

const rootElement = document.getElementById("root");
ReactDOM.render(<App />, rootElement);

Thanks.
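The `buildDetectedObjects` method above mixes DOM access with tensor bookkeeping, which makes it hard to check in isolation which output is which. A minimal sketch of the same logic as a pure function, assuming boxes come from `arraySync()` with shape [1, N, 4] in [ymin, xmin, ymax, xmax] order (the order the code above already assumes), scores from `arraySync()` with shape [1, N], and classes from `dataSync()`. `buildDetections` and its option names are hypothetical. Two details differ from the original: it iterates over `scores[0]`, because `arraySync()` keeps the batch dimension, and it falls back to a generic label when a class id is missing from the label map. The fallback matters because TF Object Detection API label maps start at id 1 while `classesDir` only defines id 0; if the model reports 1 for 'connection' (the last element of the dump contains only 1s), `classesDir[classes[i]].name` would throw.

```javascript
// Pure version of buildDetectedObjects: no DOM access, explicit image size.
// Assumed inputs:
//   boxes   - arraySync() of the boxes output, [1, N, 4], [ymin, xmin, ymax, xmax] in [0, 1]
//   scores  - arraySync() of the scores output, [1, N]
//   classes - dataSync() of the classes output, length N
function buildDetections({ boxes, scores, classes }, { threshold, labels, width, height }) {
  const detections = [];
  // scores[0]: drop the batch dimension before walking over the detections.
  scores[0].forEach((score, i) => {
    if (score <= threshold) return;
    const [ymin, xmin, ymax, xmax] = boxes[0][i];
    const classId = classes[i];
    // Generic label instead of a TypeError when the id is not in the label map.
    const label = labels[classId] ? labels[classId].name : `class ${classId}`;
    detections.push({
      class: classId,
      label,
      score,
      // [x, y, width, height] in pixels, ready for ctx.strokeRect().
      bbox: [xmin * width, ymin * height, (xmax - xmin) * width, (ymax - ymin) * height],
    });
  });
  return detections;
}
```

Keeping `score` as a number, rather than the string produced by `toFixed(4)`, leaves formatting to the drawing code, which already calls `toFixed(2)`. In the component, `width` and `height` would be the video element's `offsetWidth` and `offsetHeight`, as in the original method.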

It depends on how you labeled your data, but usually the output is something like an array whose elements each contain:

x_center,y_center,height,width,class,score

or

x0,y0,xf,yf,class,score
(where x0,y0 is the upper-left corner and xf,yf the lower-right corner of the bounding box)

In either case, you have to multiply:

(x0, xf, x_center, etc.) * runtime_image_width
(y0, yf, y_center, etc.) * runtime_image_height

because the predicted positions are fractions of the image size, not absolute pixel values.
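How that multiplication looks depends on the coordinate order. A minimal sketch for two layouts, each converting a normalized box into the [x, y, width, height] rectangle that canvas `strokeRect()` expects. The corner variant assumes [ymin, xmin, ymax, xmax], the order the question's `buildDetectedObjects` reads and the one the TF Object Detection API exports (y first, unlike the x0,y0,xf,yf listing above); the center variant follows the x_center,y_center,height,width order listed above. Both function names are hypothetical.

```javascript
// Normalized box (values in [0, 1]) -> pixel rectangle [x, y, width, height].

// Corner layout [ymin, xmin, ymax, xmax] (TF Object Detection API order).
function cornerBoxToPixels([ymin, xmin, ymax, xmax], imageWidth, imageHeight) {
  const x = xmin * imageWidth;
  const y = ymin * imageHeight;
  return [x, y, xmax * imageWidth - x, ymax * imageHeight - y];
}

// Center layout [x_center, y_center, height, width].
function centerBoxToPixels([xCenter, yCenter, height, width], imageWidth, imageHeight) {
  const w = width * imageWidth;
  const h = height * imageHeight;
  // Shift from the center to the upper-left corner by half the box size.
  return [xCenter * imageWidth - w / 2, yCenter * imageHeight - h / 2, w, h];
}
```

`imageWidth` and `imageHeight` should be the size at which the boxes are drawn (the question's code uses the video element's `offsetWidth` and `offsetHeight`), not the model's input size.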

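Long story short: you will have to test different combinations, or assume the output follows exactly the format you labeled with. But the example you gave is quite simple: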

const boxes = predictions[4].arraySync();
const scores = predictions[5].arraySync();
const classes = predictions[6].dataSync(); 

@guinther-kovalski I figured it out by looking at the shapes and sizes of the tensors and at the detection squares drawn on the display:

const boxes = predictions[2].arraySync();
const scores = predictions[7].arraySync();
const classes = predictions[1].dataSync();      

But how can this be figured out just by reading the output tensors? Is there any way to make these tensors human-readable?
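Two things can make the outputs readable without guessing from drawn boxes. First, `executeAsync` returns tensors in the order of `model.outputNodes` by default, and it also accepts a second argument listing the output node names to fetch, in the order you want them back. For a model converted from a TF2 SavedModel those node names are often generic (`Identity`, `Identity_1`, ...); when present, the `signature` section of `model.json` maps names such as `detection_boxes` to them. Second, when names are not available, the shapes and value patterns are distinctive enough for a heuristic. The sketch below assumes the usual TF Object Detection API SSD export and that the number of classes is known; `guessDetectionOutputs` and its options are hypothetical. It would mislabel outputs for a model with exactly 3 classes, because the multiclass scores would then also have 4 columns.

```javascript
// Heuristic labelling of SSD outputs exported with the TF Object Detection API.
// Each entry in `outputs` is { shape, values }, e.g. built from a tensor as
// { shape: t.shape, values: Array.from(t.dataSync()) }.
function guessDetectionOutputs(outputs, { maxDetections = 100, numClasses = 1 } = {}) {
  const inUnitRange = (v) => v >= 0 && v <= 1;
  const nonIncreasing = (vals) => vals.every((v, i) => i === 0 || v <= vals[i - 1]);

  return outputs.map(({ shape, values }, index) => {
    const last = shape[shape.length - 1];
    let guess = "unknown";
    if (shape.length === 1) {
      // A single number per batch: how many detections are valid.
      guess = "num_detections";
    } else if (shape.length === 3 && shape[1] !== maxDetections) {
      // One row per anchor (1917 in the question's dump), before non-max suppression.
      guess = last === 4 ? "raw_detection_boxes" : "raw_detection_scores";
    } else if (shape.length === 3) {
      // One row per final detection: 4 box coordinates or numClasses + 1 scores.
      guess = last === 4 ? "detection_boxes" : "detection_multiclass_scores";
    } else if (shape.length === 2 && values.every(Number.isInteger)) {
      // Integer-valued: label-map class ids are small, anchor indices are not.
      guess = values.every((v) => v <= numClasses) ? "detection_classes" : "detection_anchor_indices";
    } else if (shape.length === 2 && values.every(inUnitRange) && nonIncreasing(values)) {
      // Scores lie in [0, 1] and come back sorted from highest to lowest.
      guess = "detection_scores";
    }
    return { index, shape, guess };
  });
}
```

In the app, `console.table(guessDetectionOutputs(predictions.map((t) => ({ shape: t.shape, values: Array.from(t.dataSync()) }))))` prints one labelled row per output. Judging only by the shapes and the values visible in the dump in the question, these rules would pick element 5 for boxes, element 2 for scores and element 7 for classes; since that differs from the indices above, it is worth logging the labels for the exact `model.json` in use.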


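Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.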

 
Guangdong ICP registration No. 18138465  © 2020-2024 STACKOOM.COM