
LSTMs with multi-dimensional output targets

Given a time series of 3D vectors, [x, y, z] where x, y, and z are arbitrary integers, I'd like to build a model that predicts the next vector in the series and captures patterns in each of the dimensions x, y, z.

So if X = [[0, 0, 6], [1, 0, 0], [9, 9, 9], [3, 0, 3], [1, 2, 3]] and I give my model the 4-element sequence [[0, 0, 6], [1, 0, 0], [9, 9, 9], [3, 0, 3]], it should predict [1, 2, 3].

I can't just one-hot encode each vector since the numbers can have arbitrary values, so I'm wondering how I can accomplish this. Any insight is greatly appreciated, thank you!

Your input, in this case, is just the vector. At timestep 1 the vector is [0,0,6], at timestep 2 the vector is [1,0,0], and so on. For the output, pass the LSTM's output through a fully connected layer that transforms it to the correct output size (3 values here, one each for x, y, and z).

Assuming your sequence length is fixed, you really don't have any preprocessing to do here, except perhaps standardizing or rescaling your inputs so they aren't very large numbers.
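As a rough illustration of that preprocessing, here is a minimal pure-Python sketch. The helper names `standardize` and `make_windows` are made up for this example, and zero-mean/unit-variance scaling is just one common choice of rescaling. In practice you would compute the means and standard deviations on your training data only and reuse them at prediction time.

```python
from statistics import mean, pstdev

def standardize(series):
    """Rescale each dimension to zero mean and unit variance."""
    columns = list(zip(*series))
    means = [mean(c) for c in columns]
    # Guard against a constant column, whose standard deviation is 0.
    stds = [pstdev(c) or 1.0 for c in columns]
    scaled = [[(v - m) / s for v, m, s in zip(row, means, stds)]
              for row in series]
    return scaled, means, stds

def make_windows(series, seq_len):
    """Pair every run of seq_len vectors with the vector that follows it."""
    inputs, targets = [], []
    for i in range(len(series) - seq_len):
        inputs.append(series[i:i + seq_len])
        targets.append(series[i + seq_len])
    return inputs, targets

X = [[0, 0, 6], [1, 0, 0], [9, 9, 9], [3, 0, 3], [1, 2, 3]]

# With seq_len=4 the five vectors give one training pair:
# the first four vectors as input, the fifth as the target.
inputs, targets = make_windows(X, seq_len=4)

# Scaled copy of the series, plus the statistics needed to undo the scaling.
scaled, means, stds = standardize(X)
```

To turn a prediction made in the scaled space back into the original units, multiply each component by its entry in `stds` and add the matching entry in `means`.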

In general, an RNN works a lot like a fully connected network. In fact, an LSTM cell is made up of four fully connected layers (three gates and one candidate update) that are piped together in a non-trivial way. But from the perspective of what you put in and what you get out, think of it as a simple fully connected network applied at each timestep.
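To make that picture concrete, here is a hedged pure-Python sketch of one LSTM cell built from four fully connected layers, followed by a fully connected readout that maps the final hidden state to a 3-dimensional prediction. It follows the standard LSTM equations from the post linked below; the function names, the hidden size of 8, and the random initialization are assumptions for illustration. The weights are untrained, so the prediction is meaningless; a real model would use a library such as Keras or PyTorch and learn the weights.

```python
import math
import random

def dense(weights, bias, vector):
    """One fully connected layer: weights @ vector + bias."""
    return [sum(w * v for w, v in zip(row, vector)) + b
            for row, b in zip(weights, bias)]

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def init_layer(n_out, n_in, rng):
    """Small random weights and zero biases (illustrative, untrained)."""
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(n_in)]
               for _ in range(n_out)]
    return weights, [0.0] * n_out

def lstm_step(params, x, h, c):
    """Advance the cell by one timestep; returns the new (h, c)."""
    z = list(x) + h  # concatenate the input with the previous hidden state
    f = [sigmoid(a) for a in dense(*params["forget"], z)]       # forget gate
    i = [sigmoid(a) for a in dense(*params["input"], z)]        # input gate
    g = [math.tanh(a) for a in dense(*params["candidate"], z)]  # candidate
    o = [sigmoid(a) for a in dense(*params["output"], z)]       # output gate
    c = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
    return h, c

def predict_next(params, readout, sequence, hidden_size):
    """Run the cell over the sequence, then map the last h to the output."""
    h, c = [0.0] * hidden_size, [0.0] * hidden_size
    for x in sequence:
        h, c = lstm_step(params, x, h, c)
    return dense(*readout, h)

rng = random.Random(0)
input_size, hidden_size = 3, 8
params = {name: init_layer(hidden_size, input_size + hidden_size, rng)
          for name in ("forget", "input", "candidate", "output")}
readout = init_layer(input_size, hidden_size, rng)

sequence = [[0, 0, 6], [1, 0, 0], [9, 9, 9], [3, 0, 3]]
prediction = predict_next(params, readout, sequence, hidden_size)
```

Here `prediction` is a list of three floats, one per dimension, which is the shape the loss compares against the target vector.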

You can read up more on my last paragraph here: http://colah.github.io/posts/2015-08-Understanding-LSTMs/

If your sequence length is variable, you would typically add an extra input that flags the prediction step. This could simply be all zeros, such as:

X = [[0, 0, 6], [1, 0, 0], [9, 9, 9], [3, 0, 3], [0, 0, 0]]

or, if [0,0,0] is a valid datapoint, you could add a feature that flags each timestep as an input or a prediction, such as:

X = [[0, 0, 0, 6], [0, 1, 0, 0], [0, 9, 9, 9], [0, 3, 0, 3], [1, 0, 0, 0]]

Here the first value in each vector indicates whether the timestep is an input (0) or a prediction (1).
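A small sketch of building that flagged sequence from the raw vectors; the helper name `add_prediction_flag` is made up for this example:

```python
def add_prediction_flag(sequence):
    """Prefix each input vector with 0 and append a flagged placeholder row."""
    width = len(sequence[0])
    flagged = [[0] + list(vector) for vector in sequence]
    # The final row carries flag 1 and zeros in place of the unknown vector.
    flagged.append([1] + [0] * width)
    return flagged

sequence = [[0, 0, 6], [1, 0, 0], [9, 9, 9], [3, 0, 3]]
flagged = add_prediction_flag(sequence)
# flagged == [[0, 0, 0, 6], [0, 1, 0, 0], [0, 9, 9, 9], [0, 3, 0, 3], [1, 0, 0, 0]]
```

Note that the model's input size grows from 3 to 4 with this encoding, while the output size stays at 3.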

The model produces an output at every timestep, but you ignore all of them except the last. Your loss function will be based only on the output of the last timestep.
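As a sketch of what "only the last timestep" means in practice, the snippet below computes a loss from the final output alone. Mean squared error is an assumption here (a common choice for real-valued targets, not something the answer prescribes), and the per-timestep outputs are made-up numbers for illustration.

```python
def last_step_mse(outputs, target):
    """Mean squared error between the final timestep's output and the target."""
    final = outputs[-1]  # outputs from earlier timesteps are ignored
    return sum((p - t) ** 2 for p, t in zip(final, target)) / len(target)

# Hypothetical model outputs, one 3D vector per timestep (illustrative only).
outputs = [
    [0.5, 0.1, 5.0],
    [1.0, 0.0, 0.2],
    [8.0, 9.0, 9.5],
    [2.5, 0.0, 3.0],
    [1.0, 2.5, 3.0],  # the prediction step
]
target = [1, 2, 3]

loss = last_step_mse(outputs, target)
```

Only the last row contributes to `loss`, so changing any earlier row leaves the value unchanged.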


 