Converting a pandas dataframe to a list of lists for input into an RNN

In Python, I have a dataframe, imported with pandas.read_csv, that looks like this (example data):

Cust_id  time_to_event_f  event_id  event_sub_id
1        100              5         2
1        95               1         3
1        44               3         1
2        99               5         5
2        87               2         2
2        12               3         3

The data are ordered by Cust_id and then time_to_event_f. I am trying to convert this dataframe into a tensor of dimensions [2,3,3] so that for each customer id I have a sequential list of time_to_event_f, event_id, and event_sub_id. The idea is to use this as an input into an RNN in tensorflow. I am following this tutorial so I am trying to get my data in a similar format.

You can transform the original dataframe d into a series keyed by customer id by setting Cust_id as the index and then stacking the remaining columns:

d.set_index('Cust_id').stack()

The resulting series looks like this:

Cust_id                 
1        time_to_event_f    100
         event_id             5
         event_sub_id         2
         time_to_event_f     95
         event_id             1
         event_sub_id         3
         time_to_event_f     44
         event_id             3
         event_sub_id         1
2        time_to_event_f     99
         event_id             5
         event_sub_id         5
         time_to_event_f     87
         event_id             2
         event_sub_id         2
         time_to_event_f     12
         event_id             3
         event_sub_id         3
dtype: int64

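Given this representation, your task is easy: take the underlying values ndarray and reshape it to the target shape, here 2 customers x 3 events per customer x 3 features. The reshape relies on every customer having the same number of rows: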

series.values.reshape([2, 3, 3])
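If you would rather not hard-code [2, 3, 3], the shape can be derived from the data. The following is only a minimal sketch: it assumes, as in the example, that the rows are already sorted by Cust_id and that every customer has the same number of events. The names features, n_customers and n_steps are introduced here for illustration and do not come from the question.

```python
import pandas as pd

# Example data from the question, built inline so the snippet is self-contained.
d = pd.DataFrame({
    "Cust_id":         [1, 1, 1, 2, 2, 2],
    "time_to_event_f": [100, 95, 44, 99, 87, 12],
    "event_id":        [5, 1, 3, 5, 2, 3],
    "event_sub_id":    [2, 3, 1, 5, 2, 3],
})

features = ["time_to_event_f", "event_id", "event_sub_id"]  # illustrative name

# Number of rows (events) per customer.
steps_per_customer = d.groupby("Cust_id").size()

# A plain reshape only works if every customer has the same number of events.
if steps_per_customer.nunique() != 1:
    raise ValueError("customers have different numbers of events; reshape would misalign rows")

n_customers = len(steps_per_customer)
n_steps = int(steps_per_customer.iloc[0])

# Rows are assumed to be sorted by Cust_id, so consecutive rows belong together.
time_array = d[features].to_numpy().reshape(n_customers, n_steps, len(features))
print(time_array.shape)  # (2, 3, 3)
```

The explicit check is there because reshape does not know about customers: with unequal group sizes it would either fail or silently mix rows from different customers.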

This array can be fed as input to a TensorFlow RNN. The complete code is below:

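import pandas as pd
from io import StringIO

# Example data from the question: one row per (customer, event).
s = StringIO("""
1       100             5 2
1       95              1 3
1       44              3 1
2       99              5 5
2       87              2 2
2       12              3 3
""".strip())

# The columns are whitespace-separated and the data has no header row.
d = pd.read_table(s, names=['Cust_id', 'time_to_event_f', 'event_id', 'event_sub_id'], sep=r'\s+')

# Index by customer, then stack the three feature columns into one long series.
series = d.set_index('Cust_id').stack()

# 2 customers x 3 events per customer x 3 features.
time_array = series.values.reshape([2, 3, 3])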
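If you literally want a list of lists, as in the title, one possible alternative is to group by customer and convert each group to nested Python lists. This is a sketch under my own assumptions rather than part of the approach above: sequences and features are illustrative names, and sort=False is used only to keep customers in the order they appear in the dataframe.

```python
import pandas as pd

# Example data from the question, built inline so the snippet is self-contained.
d = pd.DataFrame({
    "Cust_id":         [1, 1, 1, 2, 2, 2],
    "time_to_event_f": [100, 95, 44, 99, 87, 12],
    "event_id":        [5, 1, 3, 5, 2, 3],
    "event_sub_id":    [2, 3, 1, 5, 2, 3],
})

features = ["time_to_event_f", "event_id", "event_sub_id"]  # illustrative name

# One inner list per customer; each inner list holds that customer's
# events in row order, and each event is a list of the three features.
sequences = [
    group[features].values.tolist()
    for _, group in d.groupby("Cust_id", sort=False)
]

print(len(sequences))   # 2 customers
print(sequences[0])     # [[100, 5, 2], [95, 1, 3], [44, 3, 1]]
```

Unlike the reshape, this also runs when customers have different numbers of events, but the result is then a ragged list that would need padding or truncating before it can become a rectangular tensor.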
