
How to restrict the sequence prediction in an LSTM model to match a specific pattern?

I have created a word-level text generator using an LSTM model. But in my case, not every word is suitable to be selected: I want the generated words to match two additional conditions:

  1. Each word has a map: every vowel is written as 1 and every other character as 0 (for instance, overflow becomes 10100010). The generated sentence then needs to match a given structure, for instance 01001100 (hi = 01 and friend = 001100).
  2. The last vowel of the last word must be the one provided. Let's say it is e (friend would then do the job, since its last vowel is e).

Thus, to handle this scenario, I've created a pandas dataframe with the following structure (a construction sketch follows the table):

word    last_vowel  word_map
-----   ---------   ----------
hello   o           01001
stack   a           00100
jhon    o           0010
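
For reference, here is a minimal sketch of how such a dataframe could be built (the helper names vowel_map and last_vowel are mine, not from my actual code):

    import pandas as pd

    VOWELS = set("aeiou")

    def vowel_map(word):
        """'1' for every vowel, '0' for every other character."""
        return "".join("1" if c in VOWELS else "0" for c in word)

    def last_vowel(word):
        """The last vowel of the word, or None if it has none."""
        vowels = [c for c in word if c in VOWELS]
        return vowels[-1] if vowels else None

    words = ["hello", "stack", "jhon"]
    df = pd.DataFrame({
        "word": words,
        "last_vowel": [last_vowel(w) for w in words],
        "word_map": [vowel_map(w) for w in words],
    })
    # df now mirrors the table above.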

This is my current workflow (a code sketch of the loop follows the list):

  1. Given the sentence structure, I choose a random word from the dataframe whose word map matches the start of the pattern. For instance, if the sentence structure is 0100100100100, we can choose the word hello, as its vowel structure is 01001.
  2. I subtract the selected word from the remaining structure: 0100100100100 becomes 00100100, as we've removed the initial 01001 (hello).
  3. I retrieve all the words from the dataframe that match the start of the remaining structure, in this case stack (00100) and jhon (0010).
  4. I pass the current sentence content (just hello for now) to the LSTM model, and it returns the weights of each word.
  5. But I don't just want to select the best option overall, I want the best option among the candidates from step 3. So I choose the word with the highest estimated weight within that list, in this case stack.
  6. Repeat from point 2 until the remaining sentence structure is empty.
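
A minimal sketch of this loop, using the dataframe built above; lstm_scores(sentence, candidates) is a placeholder for the actual model call and should return one weight per candidate word:

    def generate_sentence(df, structure, lstm_scores):
        """Greedy left-to-right generation that follows the vowel structure."""
        sentence, remaining = [], structure
        while remaining:
            # Words whose vowel map is a prefix of what is left of the structure.
            candidates = df[df["word_map"].apply(remaining.startswith)]
            if candidates.empty:
                return None                      # dead end: nothing fits the pattern
            if not sentence:
                # Step 1: start with a random matching word.
                row = candidates.sample(1).iloc[0]
            else:
                # Steps 4-5: let the LSTM rank only the words that fit.
                weights = lstm_scores(sentence, candidates["word"].tolist())
                row = candidates.iloc[max(range(len(weights)), key=weights.__getitem__)]
            sentence.append(row["word"])
            remaining = remaining[len(row["word_map"]):]   # step 2: consume the prefix
        return sentence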

That works like a charm, but there is one remaining condition to handle: the last vowel of the sentence.

My way to deal with this issue is the following (sketched in code after the list):

  1. Generate 1000 sentences, forcing the last vowel to be the one specified.
  2. Get the RMSE of the weights returned by the LSTM model; the better the output, the higher the weights will be.
  3. Select the sentence with the highest score.
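
A sketch of that procedure, reusing generate_sentence from the sketch above; it filters samples by the last-vowel constraint instead of forcing it during generation, and sentence_score stands in for the RMSE-based ranking of the LSTM weights:

    def best_of_n(df, structure, target_vowel, lstm_scores, sentence_score, n=1000):
        """Sample n constrained sentences and keep the highest-scoring valid one."""
        best, best_score = None, float("-inf")
        for _ in range(n):
            sentence = generate_sentence(df, structure, lstm_scores)
            if not sentence:
                continue
            # Keep only sentences whose final word carries the required last vowel.
            final = df.loc[df["word"] == sentence[-1], "last_vowel"].iloc[0]
            if final != target_vowel:
                continue
            score = sentence_score(sentence)
            if score > best_score:
                best, best_score = sentence, score
        return best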

Do you think there is a better approach? Maybe a GAN or reinforcement learning?

EDIT: I think another approach would be to add a WFST (weighted finite-state transducer). I've heard about the pynini library, but I don't know how to apply it to my specific context.

If you are happy with your approach, the easiest way might be to train your LSTM on the reversed sequences, so that it learns to give the weight of the previous word rather than the next one. In that case, you can use the method you already employ, except that the first subset of words would be the one satisfying the last-vowel constraint. I don't believe that this is guaranteed to produce the best result.
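
If the training data is a list of tokenized sentences, preparing the reversed training set is a one-liner (toy example; the corpus here is made up):

    corpus = [["hi", "friend"], ["hello", "stack", "jhon"]]   # toy tokenized corpus
    # Reverse every sentence so the LSTM learns right-to-left transitions,
    # i.e. the weight of the previous word given the words that follow it.
    reversed_corpus = [sentence[::-1] for sentence in corpus]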

Now, if that reversal is not possible, or if, after reading my answer further, you find that this doesn't find the best solution, then I suggest using a pathfinding algorithm. This is similar to reinforcement learning, but not statistical, since the weights computed by the trained LSTM are deterministic. What you currently use is essentially a greedy depth-first search which, depending on the LSTM output, might even be optimal: say, if the LSTM gives you a guaranteed monotonic increase of the sum that doesn't vary much between the acceptable next words (i.e. the difference between the N-1 and N word sequences is much larger than the difference between the different options for the Nth word). In the general case, when there is no clear heuristic to help you, you will have to perform an exhaustive search. If you can come up with an admissible heuristic, you can use A* instead of Dijkstra's algorithm in the first option below, and the better your heuristic is, the faster it will run.

I suppose it is clear, but just in case: your graph connectivity is defined by your constraint sequence. The initial node (the 0-length sequence with no words) is connected to any word in your dataframe that matches the beginning of your constraint sequence. So you do not have the graph as a data structure, just its compressed description in the form of this constraint.

EDIT: As requested in the comments, here are additional details. There are a couple of options:

  1. Apply Dijkstra's algorithm multiple times. Dijkstra's search finds the shortest path between 2 known nodes, while in your case we only have the initial node (the 0-length sequence with no words) and the final words are unknown (a sketch of this option follows the list).

    • Find all acceptable last words (those that satisfy both the pattern and vowel constraints).
    • Apply Dijkstra's search for each one of those, finding the largest word sequence weight sum for each of them.
    • Dijkstra's algorithm is tailored to finding the shortest path, so to apply it directly you will have to negate the weights at each step and pick the smallest among the nodes that haven't been visited yet.
    • After finding all solutions (sentences that end with one of the last words you identified initially), select the smallest one (its negated sum corresponds exactly to the largest weight sum among all solutions).
  2. Modify your existing depth-first search to do an exhaustive search (also sketched after the list).

    • Perform the search as you described in the OP and, if the last step yields a solution (that is, if a last word with the correct vowel is available at all), record its weight.
    • Roll back one step to the previous word and pick the second-best option among the previous words. If there was no solution at all, you might be able to discard all the words of the same length at that step. If there was a solution, it depends on whether your LSTM provides different weights depending on the previous word; it likely does, and in that case you have to perform that operation for all the words in the previous step.
    • When you run out of words at that step, move one step further up and restart from there.
    • Keep track of the current winner at all times, as well as the list of unvisited nodes at every step, and perform the exhaustive search. Eventually, you will find the best solution.
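
A sketch of option 1, with simplifying assumptions that are mine: a node is the pair (offset into the constraint string, last word chosen), and score_next(prev_word, word) stands in for the LSTM weight, taken to depend only on the previous word so that the graph formulation applies (with a full-history LSTM the node would have to be the whole prefix, which collapses into option 2). Because the negated weights can be negative, the sketch keeps relaxing nodes until the queue is empty, which is safe here since the graph is acyclic:

    import heapq

    def dijkstra_generate(df, structure, target_vowel, score_next):
        """Largest-weight sentence via shortest-path search over negated weights."""
        rows = list(df[["word", "word_map"]].itertuples(index=False))
        vowel_of = dict(zip(df["word"], df["last_vowel"]))

        start = (0, None)                         # nothing consumed, no previous word
        dist = {start: 0.0}
        parent = {}                               # node -> (previous node, word used)
        heap = [(0.0, start)]

        while heap:
            cost, node = heapq.heappop(heap)
            if cost > dist.get(node, float("inf")):
                continue                          # stale queue entry
            offset, prev = node
            remaining = structure[offset:]
            for word, word_map in rows:
                if not remaining.startswith(word_map):
                    continue
                child = (offset + len(word_map), word)
                child_cost = cost - score_next(prev, word)    # negate the LSTM weight
                if child_cost < dist.get(child, float("inf")):
                    dist[child] = child_cost
                    parent[child] = (node, word)
                    heapq.heappush(heap, (child_cost, child))

        # Goal nodes: the whole constraint is consumed and the last vowel matches.
        goals = [n for n in dist
                 if n[0] == len(structure) and vowel_of.get(n[1]) == target_vowel]
        if not goals:
            return None
        node = min(goals, key=dist.get)           # smallest negated sum = largest sum
        sentence = []
        while node != start:
            node, word = parent[node]
            sentence.append(word)
        return sentence[::-1]

And a sketch of option 2, the exhaustive depth-first search with backtracking; lstm_weight(sentence, word) is again a placeholder and may depend on the whole sentence so far, which is why every branch has to be re-scored:

    def exhaustive_search(df, structure, target_vowel, lstm_weight):
        """Exhaustive DFS with backtracking; returns the best-scoring sentence."""
        rows = list(df[["word", "word_map", "last_vowel"]].itertuples(index=False))
        best = {"sentence": None, "score": float("-inf")}

        def recurse(sentence, remaining, score):
            if not remaining:                              # the constraint is consumed
                last = next(r.last_vowel for r in rows if r.word == sentence[-1])
                if last == target_vowel and score > best["score"]:
                    best["sentence"], best["score"] = list(sentence), score
                return
            for row in rows:
                if remaining.startswith(row.word_map):
                    sentence.append(row.word)
                    recurse(sentence, remaining[len(row.word_map):],
                            score + lstm_weight(sentence[:-1], row.word))
                    sentence.pop()                         # backtrack, try the next word
            # Running out of fitting words here simply returns to the caller,
            # which is the "move one step further up and restart" part above.

        if structure:
            recurse([], structure, 0.0)
        return best["sentence"]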

I would reach for a Beam Search here.

This is much like your current approach of starting 1000 solutions randomly, but instead of expanding each of those paths independently, it expands all candidate solutions together in a step-by-step manner.

Sticking with the current candidate count of 1000, it would look like this (a sketch follows the list):

  1. Generate 1000 stub solutions, for example using random starting points or selected from some "sentence start" model.
  2. For each candidate, compute the best extensions based on your LSTM language model which fit the constraints. This works just as it does in your current approach, except that you could also try more than one option. For example, using the best 5 choices for the next word would produce 5000 child candidates.
  3. Compute a score for each of those partial solution candidates, then reduce back to 1000 candidates by keeping only the best scoring options.
  4. Repeat steps 2 and 3 until all candidates cover the full vowel sequence, including the end constraint.
  5. Take the best scoring of these 1000 solutions.
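
A sketch of this procedure, again with lstm_scores(sentence, candidates) standing in for the language model; for simplicity it starts from a single empty stub instead of 1000 random ones, and it scores partial solutions by the plain sum of weights — both are my assumptions, not requirements of the method:

    def beam_search(df, structure, target_vowel, lstm_scores,
                    beam_width=1000, expansions=5):
        """Beam search over word sequences constrained by the vowel structure."""
        beam = [(0.0, [], structure)]                  # (score, words so far, remaining)
        finished = []
        while beam:
            children = []
            for score, words, remaining in beam:
                if not remaining:                      # covers the full vowel sequence
                    last = df.loc[df["word"] == words[-1], "last_vowel"].iloc[0] if words else None
                    if last == target_vowel:           # step 4: the end constraint
                        finished.append((score, words))
                    continue
                # Step 2: best extensions that still fit the constraint.
                fits = df[df["word_map"].apply(remaining.startswith)]
                if fits.empty:
                    continue
                weights = lstm_scores(words, fits["word"].tolist())
                ranked = sorted(zip(weights, fits["word"], fits["word_map"]),
                                reverse=True)[:expansions]
                for w, word, word_map in ranked:
                    children.append((score + w, words + [word],
                                     remaining[len(word_map):]))
            # Step 3: keep only the best `beam_width` partial solutions.
            beam = sorted(children, key=lambda c: c[0], reverse=True)[:beam_width]
        # Step 5: best completed solution, if any.
        return max(finished, key=lambda c: c[0])[1] if finished else None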

You can play with the candidate scoring to trade off completed or longer solutions vs very good but short fits.
