
MultiOutput Classification with TensorFlow Extended (TFX)

I'm quite new to TFX (TensorFlow Extended) and have been going through the sample tutorial on the TensorFlow portal to understand it well enough to apply it to my own dataset.

In my scenario, instead of predicting a single label, the problem at hand requires me to predict 2 outputs (category 1, category 2).

I've done this using the pure TensorFlow Keras Functional API and that works fine, but I'm now looking to see whether it can be fitted into a TFX pipeline.

The error occurs at the Trainer stage of the pipeline, inside _input_fn, and I suspect it's because I'm not correctly splitting the given data into a (features, labels) tensor pair.

Scenario:

  1. Each row of the input data comes in the form of [Col1, Col2, Col3, ClassificationA, ClassificationB]

  2. ClassificationA and ClassificationB are the categorical labels which I'm trying to predict using the Keras Functional model

The output layers of the Keras functional model look like the below, where 2 outputs are both attached to a single shared dense layer (Note: the _xf suffix on each name just indicates that I've encoded the classes to int representations)

output_1 = tf.keras.layers.Dense(TargetA_Class, activation='sigmoid', name='ClassificationA_xf')(dense)
output_2 = tf.keras.layers.Dense(TargetB_Class, activation='sigmoid', name='ClassificationB_xf')(dense)

model = tf.keras.Model(inputs=inputs, outputs=[output_1, output_2])

In the trainer module file, I've imported the required packages at the start of the file:

import tensorflow_transform as tft
from tfx.components.tuner.component import TunerFnResult
import tensorflow as tf
from typing import List, Text
from tfx.components.trainer.executor import TrainerFnArgs
from tfx.components.trainer.fn_args_utils import DataAccessor, FnArgs
from tfx_bsl.tfxio import dataset_options

The current _input_fn in the trainer module file looks like the below (following the tutorial):

def _input_fn(file_pattern: List[Text],
              data_accessor: DataAccessor,
              tf_transform_output: tft.TFTransformOutput,
              batch_size: int = 200) -> tf.data.Dataset:
  """Helper function that Generates features and label dataset for tuning/training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, indices) tuple where features is a
      dictionary of Tensors, and indices is a single Tensor of label indices.
      
  """
  return data_accessor.tf_dataset_factory(
      file_pattern,
      dataset_options.TensorFlowDatasetOptions(
          batch_size=batch_size, 
          #label_key=[_transformed_name(x) for x in _CATEGORICAL_LABEL_KEYS]),
          label_key=_transformed_name(_CATEGORICAL_LABEL_KEYS[0]), _transformed_name(_CATEGORICAL_LABEL_KEYS[1])),
      tf_transform_output.transformed_metadata.schema)

When I run the Trainer component, the error that comes up is:

label_key=_transformed_name(_CATEGORICAL_LABEL_KEYS[0]), _transformed_name(_CATEGORICAL_LABEL_KEYS[1])),

^ SyntaxError: positional argument follows keyword argument
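
For reference, this particular error is raised by Python's parser rather than by TFX: once a call contains a keyword argument such as label_key=..., any bare value after it is rejected before the code even runs. A minimal sketch that reproduces it, where make_options is a hypothetical stand-in and not the real TensorFlowDatasetOptions:

```python
def make_options(batch_size=None, label_key=None):
    """Hypothetical stand-in for an options constructor taking keywords."""
    return {"batch_size": batch_size, "label_key": label_key}


# The second label is a bare positional value placed after keyword arguments,
# which mirrors the shape of the failing call in _input_fn.
bad_call = "make_options(batch_size=200, label_key='A_xf', 'B_xf')"

try:
    # compile() only parses the source, so the error appears without running it.
    compile(bad_call, "<example>", "eval")
    error_message = None
except SyntaxError as err:
    error_message = err.msg

print(error_message)

# One value per keyword parses and runs normally.
options = make_options(batch_size=200, label_key="A_xf")
print(options)
```

So the message says nothing yet about whether label_key can hold more than one name; it only says the call is not valid Python as written.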

I've also tried label_key=[_transformed_name(x) for x in _CATEGORICAL_LABEL_KEYS], which also gives an error.

However, if I just pass in a single label key, label_key=_transformed_name(_CATEGORICAL_LABEL_KEYS[0]), then it works fine.

FYI - _CATEGORICAL_LABEL_KEYS is nothing but a list which contains the names of the 2 outputs I'm trying to predict (ClassificationA, ClassificationB).

_transformed_name is nothing but a function to return an updated name/key for the transformed data:

def _transformed_name(key):
  return key + '_xf'

Question:

From what I can see, the label_key argument for dataset_options.TensorFlowDatasetOptions can only accept a single string (one label name), which means it may not be able to output a dataset with multiple labels.

Is there a way I can modify _input_fn so that the dataset it returns carries both output labels? Each element of the dataset would then look something like:

Feature_Tensor: {Col1_xf: Col1_transformedfeature_values, Col2_xf: Col2_transformedfeature_values, Col3_xf: Col3_transformedfeature_values}

Label_Tensor: {ClassificationA_xf: ClassA_encodedlabels, ClassificationB_xf: ClassB_encodedlabels}

Would appreciate advice from the wider TFX community!

Since label_key is optional, instead of specifying it in TensorFlowDatasetOptions you could call dataset.map afterwards and pull both labels out of each element of the dataset yourself.

I haven't tested it, but something like:

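def _data_augmentation(feature_dict):
  # Without label_key, each element is one dict holding features and labels.
  features = dict(feature_dict)
  label_keys = [_transformed_name(x) for x in _CATEGORICAL_LABEL_KEYS]
  # Pop the label tensors out so they are not also fed to the model as inputs.
  labels = {key: features.pop(key) for key in label_keys}

  return features, labels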
  

def _input_fn(file_pattern: List[Text],
              data_accessor: DataAccessor,
              tf_transform_output: tft.TFTransformOutput,
              batch_size: int = 200) -> tf.data.Dataset:
  """Helper function that Generates features and label dataset for tuning/training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, labels) tuples, where features is a
      dictionary of feature Tensors and labels is a dictionary of label
      Tensors keyed by transformed label name.
      
  """
  # No label_key here, so the labels stay inside the feature dict for now.
  dataset = data_accessor.tf_dataset_factory(
      file_pattern,
      dataset_options.TensorFlowDatasetOptions(batch_size=batch_size),
      tf_transform_output.transformed_metadata.schema)

  dataset = dataset.map(_data_augmentation)
  return dataset
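
To see the shape this produces without running a whole pipeline, here is a small sketch of the same dataset.map idea on an in-memory tf.data.Dataset. The toy values and the LABEL_KEYS constant are assumptions for illustration; in the pipeline the elements would come from tf_dataset_factory and the keys from _transformed_name applied to _CATEGORICAL_LABEL_KEYS.

```python
import tensorflow as tf

# Assumed transformed label names, matching the output layer names above.
LABEL_KEYS = ["ClassificationA_xf", "ClassificationB_xf"]


def split_features_and_labels(batch):
    """Moves the label columns out of the batch dict into their own dict."""
    features = dict(batch)
    labels = {key: features.pop(key) for key in LABEL_KEYS}
    return features, labels


# Toy stand-in for a dataset built without label_key: every element is a
# single dict that holds the features and both labels together.
raw = tf.data.Dataset.from_tensor_slices({
    "Col1_xf": [0.1, 0.2, 0.3, 0.4],
    "Col2_xf": [1.0, 2.0, 3.0, 4.0],
    "Col3_xf": [5, 6, 7, 8],
    "ClassificationA_xf": [0, 1, 2, 1],
    "ClassificationB_xf": [1, 0, 0, 1],
}).batch(2)

dataset = raw.map(split_features_and_labels)

# element_spec describes the (features, labels) structure of each element.
feature_spec, label_spec = dataset.element_spec
print(sorted(feature_spec))
print(sorted(label_spec))
```

As far as I know, Keras matches the keys of a label dictionary against the model's output names, so the labels keyed ClassificationA_xf and ClassificationB_xf should line up with the two Dense layers given those names.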
