
Tensorflow: Create a tensor with varying shapes in row dimension

I have a pandas dataframe with a structure similar to the following:

[image: enter image description here]

I want to create a 3D tensor where each element has a varying number of rows. Using the dataframe above as an example, the first element would correspond to index (key_value) 4 and have shape (10, 3). The second element would correspond to key_value 5 and have shape (6, 3). The third element would correspond to key_value 6 and have shape (3, 3). The same applies to the last element.

My goal is to create a model that analyzes the data (which is sequential) and makes a classification (binary 1 or 0 - target value not shown). I've read that structures such as RaggedTensors could achieve this, but I haven't figured out how to go from the dataframe to the tensor.

How should I go about creating the data structure described above so that I can pass it into a model that may have CNN layers and/or LSTM layers? It is important to have one 2D array per index.

I appreciate your help!

Example code to convert your DataFrame to a RaggedTensor:

import tensorflow as tf
import pandas as pd
import numpy as np

# Small stand-in for the real dataframe: 3 rows for key 4, 2 for key 5, 1 for key 6.
d = {'key_value': [4, 4, 4, 5, 5, 6], 
     'var0': [0.1, 0.11, 0.111, 0.2, 0.22, 0.3], 
     'var1': [1.1, 1.11, 1.111, 1.2, 1.22, 1.3], 
     'var2': [2.1, 2.11, 2.111, 2.2, 2.22, 2.3]}
df = pd.DataFrame(d)

# Split the dataframe into one sub-dataframe per key_value.
dfs = dict(tuple(df.groupby('key_value')))
# Drop the key column and turn each group into a 2D array of shape (rows, 3).
dfs = [dfs[k].drop('key_value', axis=1).to_numpy() for k in dfs]
# Stack the variable-length 2D arrays into a single ragged tensor.
rt = tf.ragged.constant(dfs)
print("RaggedTensors:\n{}\n".format(rt))

# Each dataset element is the 2D block of rows belonging to one key_value.
dataset = tf.data.Dataset.from_tensor_slices(rt)
print("Batches: ")
for feature_batch in dataset:
    print(feature_batch)

Outputs:

RaggedTensors:
<tf.RaggedTensor [[[0.1, 1.1, 2.1], [0.11, 1.11, 2.11], [0.111, 1.111, 2.111]], [[0.2, 1.2, 2.2], [0.22, 1.22, 2.22]], [[0.3, 1.3, 2.3]]]>

Batches:
<tf.RaggedTensor [[0.1, 1.1, 2.1], [0.11, 1.11, 2.11], [0.111, 1.111, 2.111]]>
<tf.RaggedTensor [[0.2, 1.2, 2.2], [0.22, 1.22, 2.22]]>
<tf.RaggedTensor [[0.3, 1.3, 2.3]]>
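
The question also asks how to get from this structure to a model with LSTM layers. The following is only a rough sketch of one possible route, not a definitive recipe. It builds the ragged tensor with tf.RaggedTensor.from_row_lengths, so that only the row dimension is ragged and the 3 feature columns stay a fixed dimension. It then pads the result to a dense tensor and lets a Masking layer skip the padded timesteps. The layer sizes, the padding value 0.0 and the untrained toy model are assumptions made for illustration; they do not come from the original post.

```python
import pandas as pd
import tensorflow as tf

df = pd.DataFrame({
    'key_value': [4, 4, 4, 5, 5, 6],
    'var0': [0.1, 0.11, 0.111, 0.2, 0.22, 0.3],
    'var1': [1.1, 1.11, 1.111, 1.2, 1.22, 1.3],
    'var2': [2.1, 2.11, 2.111, 2.2, 2.22, 2.3],
})

# Make rows with the same key contiguous. A stable sort keeps the original
# (sequential) order of the rows inside each key.
df = df.sort_values('key_value', kind='stable')

# Flat (total_rows, 3) feature matrix plus the number of rows per key.
values = df.drop(columns='key_value').to_numpy(dtype='float32')
row_lengths = df.groupby('key_value').size().to_numpy()

# Ragged only in the row dimension: shape (3, None, 3).
rt = tf.RaggedTensor.from_row_lengths(values, row_lengths)
print(rt.shape)

# Pad every sequence to the longest one; padded timesteps are all zeros.
padded = rt.to_tensor(default_value=0.0)

# Toy, untrained model (assumed architecture): Masking tells the LSTM to
# ignore timesteps whose features are all equal to mask_value.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(None, 3)),
    tf.keras.layers.Masking(mask_value=0.0),
    tf.keras.layers.LSTM(8),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

preds = model(padded)  # one probability per key_value
print(preds.shape)
```

One caveat with this approach: Masking treats any timestep whose features are all exactly 0.0 as padding, so if real rows can be all zeros, a different padding value (passed to both to_tensor and Masking) would be a safer choice.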
