簡體   English   中英

numpy 數組用於輸入到 tensorflow/keras 神經網絡的 dtype 是否重要?

[英]Does it matter what dtype a numpy array is for input into a tensorflow/keras neural network?

If I take a tensorflow.keras model and call model.fit(x, y) (where x and y are numpy arrays) does it matter what dtype the numpy array is? 我最好只是使dtype盡可能小(例如二進制數據的int8 )還是這會給 tensorflow/keras 額外的工作以將其轉換為浮點數?

您應該將輸入轉換為np.float32 ,這是 Keras 的默認 dtype。 查一下:

import tensorflow as tf
tf.keras.backend.floatx()
'float32'

如果你在 np.float64 中輸入np.float64 ,它會抱怨:

import tensorflow as tf
from tensorflow.keras.layers import Dense 
from tensorflow.keras import Model
from sklearn.datasets import load_iris
iris, target = load_iris(return_X_y=True)

X = iris[:, :3]
y = iris[:, 3]

ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)

class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.d0 = Dense(16, activation='relu')
    self.d1 = Dense(32, activation='relu')
    self.d2 = Dense(1, activation='linear')

  def call(self, x):
    x = self.d0(x)
    x = self.d1(x)
    x = self.d2(x)
    return x

model = MyModel()

_ = model(X)

WARNING:tensorflow:Layer my_model 正在將輸入張量從 dtype float64 轉換為 layer 的 dtype 的 float32,這是 TensorFlow 2 中的新行為。圖層具有 dtype float32 因為它的 dtype 默認為 floatx。 如果您打算在 float32 中運行此層,則可以放心地忽略此警告。 If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2. To change all layers to have dtype float64 by default, call tf.keras.backend.set_floatx('float64') . 要僅更改此圖層,請將 dtype='float64' 傳遞給圖層構造函數。 如果您是該層的作者,您可以通過將 autocast=False 傳遞給基礎層構造函數來禁用自動轉換。

可以使用 Tensorflow 進行8 位輸入的訓練,這稱為量化。 但在大多數情況下,這是具有挑戰性且不必要的(即,除非您需要在邊緣設備上部署模型)。

tl;博士將您的輸入保存在np.float32中。 另見這篇文章

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM