I am looking through the Deeplearning4j example for classifying movie reviews according to their sentiment (ReviewExample).
At lines 124-142 the N-dimensional arrays are created, and I am unsure what is happening in these lines:
Line 132:
features.put(new INDArrayIndex[]{NDArrayIndex.point(i),
NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
I can imagine that .point(i) and .point(j) address the cell in the array, but what exactly does the NDArrayIndex.all() call do here?
While I more or less understand how the feature array is built, I get totally confused by the label mask and the lastIdx variable.
Lines 138-142:
int idx = (positive[i] ? 0 : 1);
int lastIdx = Math.min(tokens.size(),maxLength);
labels.putScalar(new int[]{i,idx,lastIdx-1},1.0); //Set label: [0,1] for negative, [1,0] for positive
labelsMask.putScalar(new int[]{i,lastIdx-1},1.0); //Specify that an output exists at the final time step for this example
The label array itself is addressed by i, idx (i.e. the column/row that is set to 1.0), but I don't really get how the time-step information fits in. Is it a convention that the last parameter has to mark the last entry?
And why does the labelsMask use only i and not i, idx?
Thanks for any explanations or pointers that help to clarify my questions.
There is one index per dimension. NDArrayIndex.all() is an indicator meaning "use this whole dimension". See the ND4J user guide: http://nd4j.org/userguide
As for the 1.0: it marks the class of the label. This is a text classification problem: take the window from the text, look up its word vectors, and have the class predicted from that.
As for the label mask: the prediction of the neural net happens at the end of a sequence, so the mask marks only the final time step of each example as having an output. See: http://deeplearning4j.org/usingrnns
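The label and mask bookkeeping above can be sketched with plain Java arrays instead of INDArray. This is only an illustration under stated assumptions: the shapes [examples, 2, maxLength] for labels and [examples, maxLength] for labelsMask are inferred from the index arrays in the question's snippet, and the class name, method names, and the tokenCounts parameter (standing in for tokens.size()) are hypothetical. It shows why the mask needs only (i, lastIdx-1): it has no class dimension and records just whether an output exists at a given time step of example i.

```java
import java.util.Arrays;

public class LabelMaskSketch {
    // Assumed shape [examples][2 classes][maxLength time steps].
    // tokenCounts[i] stands in for tokens.size() and must be at least 1.
    static double[][][] buildLabels(boolean[] positive, int[] tokenCounts, int maxLength) {
        double[][][] labels = new double[positive.length][2][maxLength];
        for (int i = 0; i < positive.length; i++) {
            int idx = positive[i] ? 0 : 1;                      // class index: 0 = positive, 1 = negative
            int lastIdx = Math.min(tokenCounts[i], maxLength);  // reviews longer than maxLength are truncated
            labels[i][idx][lastIdx - 1] = 1.0;                  // one-hot label, only at the final time step
        }
        return labels;
    }

    // Assumed shape [examples][maxLength]: there is no class dimension,
    // only "does example i have an output at time step t?".
    static double[][] buildLabelsMask(int[] tokenCounts, int maxLength) {
        double[][] mask = new double[tokenCounts.length][maxLength];
        for (int i = 0; i < tokenCounts.length; i++) {
            int lastIdx = Math.min(tokenCounts[i], maxLength);
            mask[i][lastIdx - 1] = 1.0;
        }
        return mask;
    }

    public static void main(String[] args) {
        boolean[] positive = {true, false};
        int[] tokenCounts = {3, 10};   // the second review is longer than maxLength
        int maxLength = 5;
        System.out.println(Arrays.deepToString(buildLabels(positive, tokenCounts, maxLength)));
        System.out.println(Arrays.deepToString(buildLabelsMask(tokenCounts, maxLength)));
    }
}
```

In this sketch the first review (3 tokens) gets its label and mask entry at time step 2, and the second (truncated to 5 tokens) at time step 4; every other time step stays 0 and would be ignored when the mask is applied.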
Write a test and you will see what it does:
val features = Nd4j.zeros(2, 2, 3)
val toPut = Nd4j.ones(2)
features.put(Array[INDArrayIndex](NDArrayIndex.point(0), NDArrayIndex.all, NDArrayIndex.point(1)), toPut)
The result is:
[[[0.00, 1.00, 0.00],
[0.00, 1.00, 0.00]],
[[0.00, 0.00, 0.00],
[0.00, 0.00, 0.00]]]
It puts the 'toPut' vector into the features array.
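For readers without ND4J at hand, the same behaviour can be mimicked with a plain Java 3-D array. This is a rough sketch, not the library's implementation: the class and method names are hypothetical, and it only imitates the point/all/point case, where dimensions 0 and 2 are fixed and dimension 1 is traversed in full.

```java
import java.util.Arrays;

public class PutSketch {
    // Imitates features.put({point(i), all(), point(j)}, vector):
    // i fixes dimension 0, j fixes dimension 2, and the vector is
    // written along the whole of dimension 1.
    static void putPointAllPoint(double[][][] features, int i, int j, double[] vector) {
        if (vector.length != features[i].length) {
            throw new IllegalArgumentException("vector length must match the size of dimension 1");
        }
        for (int k = 0; k < vector.length; k++) {
            features[i][k][j] = vector[k];
        }
    }

    public static void main(String[] args) {
        double[][][] features = new double[2][2][3];   // like Nd4j.zeros(2, 2, 3)
        double[] toPut = {1.0, 1.0};                   // like Nd4j.ones(2)
        putPointAllPoint(features, 0, 1, toPut);
        System.out.println(Arrays.deepToString(features));
        // prints [[[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
    }
}
```

In the review example the same pattern writes one word vector (the whole of dimension 1) into example i at time step j.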