簡體   English   中英

如何使用 python 在 Tensorflow、CNN 中創建學習 model 的多個實例?

[英]How to create multiple instances of a learning model in Tensorflow, CNN using python?

where is a class with Constructor as can be found here .我有一個 class 可以實例化一個 object 其中是一個 class 與構造函數可以在這里找到。 實際上我想創建多個模型並將一個 model 的某些部分分配給新創建的 model 並丟棄舊的 model。 我在下面收到錯誤。 任何幫助將不勝感激。 我真的被困住了。 我在這里也出於類似目的問了一個問題。 因此,歡迎這兩個答案。

def train(args):
    train_data, val_data = load_data(args.input)
    train_data = prepare_data(train_data)
    val_data = prepare_data(val_data)
    with tf.variable_scope("", reuse=True) as scope:
        et = EyeTracker()
        train_loss_history, train_err_history, val_loss_history, val_err_history = et.train(train_data, val_data, \
                                            lr=args.learning_rate, \
                                            batch_size=args.batch_size, \
                                            max_epoch=args.max_epoch, \
                                            min_delta=1e-4, \
                                            patience=args.patience, \
                                            print_per_epoch=args.print_per_epoch,
                                            out_model=args.save_model)
        save some parts of the (et)
        scope.reuse_variables()
        et = EyeTracker()
        Assign some parts of previous (et) to the new one and continue training
        train_loss_history, train_err_history, val_loss_history, val_err_history = et.train(train_data, val_data, \
                                            lr=args.learning_rate, \
                                            batch_size=args.batch_size, \
                                            max_epoch=args.max_epoch, \
                                            min_delta=1e-4, \
                                            patience=args.patience, \
                                            print_per_epoch=args.print_per_epoch,
                                            out_model=args.save_model)

錯誤是

變量 conv1_eye_w 不存在,或者不是使用 tf.get_variable() 創建的。 您的意思是在 VarScope 中設置 reuse=tf.AUTO_REUSE 嗎? 如果我的問題很煩人,我真的很抱歉。

部分解決。 我將默認構造函數 init 更改為成員 function initialize()並將值作為 arguments 傳遞,如下所示。

g = tf.Graph()
with g.as_default(): 
            et = EyeTracker()
            et.initialize(96,256,384,64,96,256,384,64)            
            result_temp = et.train(n_epoch, train_data, val_data, lr=args.learning_rate, batch_size=args.batch_size, max_epoch=args.max_epoch, min_delta=1e-4, patience=args.patience, print_per_epoch=args.print_per_epoch, out_model=args.save_model)

g = tf.Graph()
with g.as_default(): 
            et = EyeTracker()
            et.initialize(80,256,384,64,96,256,384,64)           
            result_temp = et.train(n_epoch, train_data, val_data, lr=args.learning_rate, batch_size=args.batch_size, max_epoch=args.max_epoch, min_delta=1e-4, patience=args.patience, print_per_epoch=args.print_per_epoch, out_model=args.save_model)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM