簡體   English   中英

從 tensorflow js model 獲取原始 output

[英]Get raw output from tensorflow js model

我正在通過一個簡單的卷積網絡處理 MNIST 數據集。 我想通過過程中的每個步驟手動告訴程序 go,即迭代每個 Epoch 和批次,而不使用 model.fit (我想更好地了解 Z2C39BC19B5461AC36FEDCZD146 的內部工作原理)2。

目前我有:

let EPOCHS = 10;    
let batches = Math.floor(TRAIN_DATA_SIZE/BATCH_SIZE);

      for (let i = 0; i < EPOCHS; i++){
        for (let j = 0; i < batches; i++){
          let inputs = await getNextTrain(BATCH_SIZE,TRAIN_DATA_SIZE, data);

          let inputXs = inputs[0];
          let inputYs = inputs[1];

          let output = await model.evaluate(inputXs,inputYs);

遍歷每個 Epoch 和每個批次以進行 traian。 但是 model.evalutate() 不會從網絡返回 output 值,只返回定義的損失/指標。 是否可以為此添加特定指標以返回網絡的輸出

我想要的是代表網絡認為 output 用於每個輸入的 10 元素數組(或張量)(使用 MNIST 所以希望概率網絡認為每個 output 是數字 0-9)

您可以在內部循環中調用 model.fit 並設置以下參數epoch=1 initial_epoch=i \\ So that it trains only the current epoch and batch. 'i' being outer loop variable. x=inputXs y=inputYs epoch=1 initial_epoch=i \\ So that it trains only the current epoch and batch. 'i' being outer loop variable. x=inputXs y=inputYs

這將相應地更新 model 中的所有權重。 然后,您可以調用 model.evaluate 或 model.predict 或 model.get_layer 來獲取您想要查看的信息。 因為,您現在擁有每一層信息,您可以通過僅根據需要評估特定層來獨立檢查它們的 output 值(參考您提到的行以查看所有概率......等等)

隨着更新的變化:

let EPOCHS = 10;
let batches = Math.floor(TRAIN_DATA_SIZE/BATCH_SIZE);

  for (let i = 0; i < EPOCHS; i++){
    for (let j = 0; i < batches; i++){
      let inputs = await getNextTrain(BATCH_SIZE,TRAIN_DATA_SIZE, data);

      let inputXs = inputs[0];
      let inputYs = inputs[1];
      model.fit(..on above mentioned params..);

      model.evaluate() // To get loss and metric values.
      model.predict()  // To the final output of the model for input samples given.
      model.get_layer() //To get info about a particular layer and then retrieve the required info. Refer https://keras.io/layers/about-keras-layers/

      // For example to know the 1st layer's output : model.layers[0].output;

`

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM