簡體   English   中英

如何使用tf.train.Checkpoint在Tensorflow 2.0中保存和加載選定變量以及所有變量?

[英]How to save and load selected and all variables in tensorflow 2.0 using tf.train.Checkpoint?

如何將選定的變量保存在下面顯示的tensorflow 2.0中的文件中,然后使用tf.train.Checkpoint將它們加載到另一個代碼中的某些已定義變量中?

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def saveVariables(self):
        # how to save self.initList's 3,55 and 60th elements and self.moreList's 4th element

另外,請展示如何保存所有變量並使用tf.train.Checkpoint重新加載。 提前致謝。

在下面的代碼中,我將一個名為變量的數組保存到具有您選擇的名稱的.txt文件中。 該文件將與python文件位於同一文件夾中。 打開函數中的“ wb”表示以截斷形式寫入(因此刪除文件中以前的所有內容)並使用字節格式。 我用pickle處理保存/解析列表。

import pickle

    def saveVariables(self, variables): #where 'variables' is a list of variables
        with open("nameOfYourFile.txt", 'wb+') as file:
           pickle.dump(variables, file)

    def retrieveVariables(self, filename):
        variables = []
        with open(str(filename), 'rb') as file:
            variables = pickle.load(file)
        return variables

要將特定內容保存到文件中,只需將其添加為saveVariables中的variables參數即可,如下所示:

myVariables = [initList[2], initList[54], initList[59], moreList[3]]
saveVariables(myVariables)

要從文本文件中檢索具有特定名稱的變量:

myVariables = retrieveVariables("theNameOfYourFile.txt")
thirdEl = myVariables[0]
fiftyFifthEl = myVariables[1]
SixtiethEl = myVariables[2]
fourthEl = myVariables[3]

您可以在類中的任何位置添加這些功能。

為了能夠訪問示例中的initList / moreList,您應該從它們的函數中返回它們(就像我對variables列表所做的那樣)或將它們設為全局。

我不確定這是否是您的意思,但是您可以專門為要保存和還原的變量創建一個tf.train.Checkpoint對象。 請參見以下示例:

import tensorflow as tf

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()
        self.ckpt = self.makeCheckpoint()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def makeCheckpoint(self):
        return tf.train.Checkpoint(
            init3=self.initList[3], init55=self.initList[55],
            init60=self.initList[60], more4=self.moreList[4])

    def saveVariables(self):
        self.ckpt.save('./ckpt')

    def restoreVariables(self):
        status = self.ckpt.restore(tf.train.latest_checkpoint('.'))
        status.assert_consumed()  # Optional check

# Create variables
v1 = manyVariables()
# Assigned fixed values
for i, v in enumerate(v1.initList):
    v.assign(i * tf.ones_like(v))
for i, v in enumerate(v1.moreList):
    v.assign(100 + i * tf.ones_like(v))
# Save them
v1.saveVariables()

# Create new variables
v2 = manyVariables()
# Check initial values
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[-3.3441594  -0.18425298 -0.4898144  -1.2330629   0.08798431]
#  [ 1.5002227   0.99475247  0.7817361   0.3849587  -0.59548247]
#  [-0.57121766 -1.277224    0.6957546  -0.67618763  0.0510064 ]
#  [ 0.85491985  0.13310803 -0.93152267  0.10205163  0.57520276]
#  [-1.0606447  -0.16966362 -1.0448577   0.56799036 -0.90726566]]

# Restore them
v2.restoreVariables()
# Check values after restoring
print(v2.initList[2].numpy())
# [[-1.9110833   0.05956204 -1.1753829  -0.3572553  -0.95049495]
#  [ 0.31409055  1.1262076   0.47890127 -0.1699607   0.4409122 ]
#  [-0.75385517 -0.13847834  0.97012395  0.42515194 -1.4371008 ]
#  [ 0.44205236  0.86158335  0.6919655  -2.5156968   0.16496429]
#  [-1.241602   -0.15177743  0.5603795  -0.3560254  -0.18536267]]
print(v2.initList[3].numpy())
# [[3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]
#  [3. 3. 3. 3. 3.]]

如果要將所有變量保存在列表中,則可以將makeCheckpoint替換為以下內容:

def makeCheckpoint(self):
    return tf.train.Checkpoint(
        **{f'init{i}': v for i, v in enumerate(self.initList)},
        **{f'more{i}': v for i, v in enumerate(self.moreList)})

請注意,您可以具有“嵌套”檢查點,因此,更一般而言,您可以具有為變量列表創建檢查點的函數,例如:

def listCheckpoint(varList):
    # Use 'item{}'.format(i) if using Python <3.6
    return tf.train.Checkpoint(**{f'item{i}': v for i, v in enumerate(varList)})

然后您可以擁有:

def makeCheckpoint(self):
    return tf.train.Checkpoint(init=listCheckpoint(self.initList),
                               more=listCheckpoint(self.moreList))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM