[英]Train neural-net with a sequence ( currently not converging )
由於具有遞歸性質,因此我能夠通過一次輸入一項來激活一個只有一個輸入神經元的lstm,並具有一個序列。
但是,當我嘗試使用相同的技術訓練網絡時,它永遠不會收斂。 培訓將永遠持續下去。
這是我正在做的事情,我正在將自然語言字符串轉換為二進制,然后一次輸入一位數字。 之所以將其轉換為二進制,是因為網絡僅接受0到1之間的值。
我知道訓練的工作原理,因為當我使用與輸入神經元一樣多的值進行訓練時,在這種情況下為1:[0],它收斂並且訓練良好。
我想我可以分別傳遞每個數字,但是對於每個數字,它將有一個單獨的理想輸出。 當數字在另一個訓練集中以另一個理想輸出再次出現時,它不會收斂,因為例如0怎么可能屬於0和1類? 請告訴我在這個假設上我是否錯了。
如何使用序列訓練該lstm,以便在激活時對相似的序列進行相似的分類?
這是我的整個培訓師文件:https://github.com/theirf/synaptic/blob/master/src/trainer.js
這是在工作人員上訓練網絡的代碼:
workerTrain: function(set, callback, options) {
var that = this;
var error = 1;
var iterations = bucketSize = 0;
var input, output, target, currentRate;
var length = set.length;
var start = Date.now();
if (options) {
if (options.shuffle) {
function shuffle(o) { //v1.0
for (var j, x, i = o.length; i; j = Math.floor(Math.random() *
i), x = o[--i], o[i] = o[j], o[j] = x);
return o;
};
}
if(options.iterations) this.iterations = options.iterations;
if(options.error) this.error = options.error;
if(options.rate) this.rate = options.rate;
if(options.cost) this.cost = options.cost;
if(options.schedule) this.schedule = options.schedule;
if (options.customLog){
// for backward compatibility with code that used customLog
console.log('Deprecated: use schedule instead of customLog')
this.schedule = options.customLog;
}
}
// dynamic learning rate
currentRate = this.rate;
if(Array.isArray(this.rate)) {
bucketSize = Math.floor(this.iterations / this.rate.length);
}
// create a worker
var worker = this.network.worker();
// activate the network
function activateWorker(input)
{
worker.postMessage({
action: "activate",
input: input,
memoryBuffer: that.network.optimized.memory
}, [that.network.optimized.memory.buffer]);
}
// backpropagate the network
function propagateWorker(target){
if(bucketSize > 0) {
var currentBucket = Math.floor(iterations / bucketSize);
currentRate = this.rate[currentBucket];
}
worker.postMessage({
action: "propagate",
target: target,
rate: currentRate,
memoryBuffer: that.network.optimized.memory
}, [that.network.optimized.memory.buffer]);
}
// train the worker
worker.onmessage = function(e){
// give control of the memory back to the network
that.network.optimized.ownership(e.data.memoryBuffer);
if(e.data.action == "propagate"){
if(index >= length){
index = 0;
iterations++;
error /= set.length;
// log
if(options){
if(this.schedule && this.schedule.every && iterations % this.schedule.every == 0)
abort_training = this.schedule.do({
error: error,
iterations: iterations
});
else if(options.log && iterations % options.log == 0){
console.log('iterations', iterations, 'error', error);
};
if(options.shuffle) shuffle(set);
}
if(!abort_training && iterations < that.iterations && error > that.error){
activateWorker(set[index].input);
}
else{
// callback
callback({
error: error,
iterations: iterations,
time: Date.now() - start
})
}
error = 0;
}
else{
activateWorker(set[index].input);
}
}
if(e.data.action == "activate"){
error += that.cost(set[index].output, e.data.output);
propagateWorker(set[index].output);
index++;
}
}
自然語言字符串不應轉換為二進制進行規范化。 改用一鍵編碼:
另外,我建議您看一下Neataptic,而不是Synaptic。 它修復了Synaptic中的許多錯誤,並提供了更多功能供您使用。 在培訓期間,它有一個特殊的選擇,稱為clear
。 這告訴網絡每次訓練迭代都要重置上下文,因此它知道它是從頭開始的。
為什么您的網絡只有1個二進制輸入? 網絡輸入應該有意義。 神經網絡功能強大,但您卻要給他們一個艱巨的任務。
相反,您應該有多個輸入,每個字母一個。 或更理想的情況是,每個單詞一個。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.