![](/img/trans.png)
[英]How to create a CNN cross filter for Tensorflow in python?
[英]How to create multiple instances of a learning model in Tensorflow, CNN using python?
where is a class with Constructor as can be found here .我有一個 class 可以實例化一個 object 其中是一個 class 與構造函數可以在這里找到。 實際上我想創建多個模型並將一個 model 的某些部分分配給新創建的 model 並丟棄舊的 model。 我在下面收到錯誤。 任何幫助將不勝感激。 我真的被困住了。 我在這里也出於類似目的問了一個問題。 因此,歡迎這兩個答案。
def train(args):
train_data, val_data = load_data(args.input)
train_data = prepare_data(train_data)
val_data = prepare_data(val_data)
with tf.variable_scope("", reuse=True) as scope:
et = EyeTracker()
train_loss_history, train_err_history, val_loss_history, val_err_history = et.train(train_data, val_data, \
lr=args.learning_rate, \
batch_size=args.batch_size, \
max_epoch=args.max_epoch, \
min_delta=1e-4, \
patience=args.patience, \
print_per_epoch=args.print_per_epoch,
out_model=args.save_model)
save some parts of the (et)
scope.reuse_variables()
et = EyeTracker()
Assign some parts of previous (et) to the new one and continue training
train_loss_history, train_err_history, val_loss_history, val_err_history = et.train(train_data, val_data, \
lr=args.learning_rate, \
batch_size=args.batch_size, \
max_epoch=args.max_epoch, \
min_delta=1e-4, \
patience=args.patience, \
print_per_epoch=args.print_per_epoch,
out_model=args.save_model)
錯誤是
變量 conv1_eye_w 不存在,或者不是使用 tf.get_variable() 創建的。 您的意思是在 VarScope 中設置 reuse=tf.AUTO_REUSE 嗎? 如果我的問題很煩人,我真的很抱歉。
部分解決。 我將默認構造函數 init 更改為成員 function initialize()
並將值作為 arguments 傳遞,如下所示。
g = tf.Graph()
with g.as_default():
et = EyeTracker()
et.initialize(96,256,384,64,96,256,384,64)
result_temp = et.train(n_epoch, train_data, val_data, lr=args.learning_rate, batch_size=args.batch_size, max_epoch=args.max_epoch, min_delta=1e-4, patience=args.patience, print_per_epoch=args.print_per_epoch, out_model=args.save_model)
g = tf.Graph()
with g.as_default():
et = EyeTracker()
et.initialize(80,256,384,64,96,256,384,64)
result_temp = et.train(n_epoch, train_data, val_data, lr=args.learning_rate, batch_size=args.batch_size, max_epoch=args.max_epoch, min_delta=1e-4, patience=args.patience, print_per_epoch=args.print_per_epoch, out_model=args.save_model)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.