
Can't flatten output from Keras model

I have the model below, built with Keras, and I am training it using StratifiedKFold. Training works well and performance is good. Now I am trying to explain the model's predictions using the SHAP library. My dataset has shape (107012, 67), and below is the code I wrote to encode my data, train the model, and make predictions. original_X is the DataFrame I read my data into with Pandas. Most of my data is categorical; only one column contains continuous values.

ohe = OneHotEncoder()
mms = MinMaxScaler()

ct = make_column_transformer(
    (ohe, categorical_columns_encode),
    (mms, numerical_columns_encode),
    remainder='passthrough')

ct.fit(original_X.astype(str))
X = ct.transform(original_X.astype(str))
print(X.shape) # Shape of the encoded value (107012, 47726)

recall = Recall(name="recall")
prec = Precision(name="precision")
ba = BinaryAccuracy()

def get_model():
  network = Sequential()
  network.add(Input(shape=X_1.shape))
  network.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))
  network.add(Dropout(0.5))
  network.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))
  network.add(Dropout(0.5))
  network.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))
  # network.add(Flatten())
  network.add(Dense(1, activation='sigmoid'))

  network.compile(loss='binary_crossentropy',
              optimizer=Adam(learning_rate=0.001),
              metrics=[recall, prec, ba])
  return network

classifier = KerasClassifier(build_fn=get_model)
kfold = RepeatedStratifiedKFold(n_splits=3, n_repeats=3, random_state=42)

callback = EarlyStopping(
    monitor='val_recall',
    min_delta=0,
    patience=0,
    verbose=1,
    mode="auto",
    baseline=None,
    restore_best_weights=True
)

epochs_per_fold = []

for train, validation in kfold.split(X_1, y_1):
  X_train, X_validation = X_1[train], X_1[validation]
  y_train, y_validation = y_1[train], y_1[validation]

  # Printing the distribution of classes in the training set
  counter = Counter(y_train)
  print("Number of class distributions of the training set ", counter)
  print("Minority case percentage of the training set ", counter[1] / (counter[0] + counter[1]))
  
  # Training our model and saving the history of the training
  history = classifier.fit(
    x=X_train,
    y=y_train,
    verbose=1,
    epochs=30,
    shuffle=True,
    callbacks=[callback],
    class_weight={0: 1.0, 1: 3.0},
    validation_data=(X_validation, y_validation))

  # predict classes for our validation set in order to manually verify the metrics
  yhat_classes = (classifier.predict(X_validation) > 0.5).astype("int32")

  TP = 0
  FP = 0
  TN = 0
  FN = 0

  # Record our predictions for the confusion matrix to manually verify our metrics
  for p,t in zip(y_validation, yhat_classes):
    if p == 1 and t == 1:
      TP += 1
    elif p == 0 and t == 1:
      FP += 1
    elif p == 1 and t == 0:
      FN += 1
    elif p == 0 and t == 0:
      TN += 1
  
  print("\n")
  print(" "*16, "T  F")
  print("Positive result ", TP, FP, )
  print("Negative result ", TN, FN, )
  print("\n")

  # Printing the built in classification report of our model
  print(classification_report(y_validation, yhat_classes))

  report_dict = classification_report(y_validation, yhat_classes, output_dict=True)

  # Record the number of epochs this fold trained for (used later to average across folds)
  epochs_per_fold.append(len(history.history['recall']))
  print(yhat_classes)

Here I am trying to use DeepExplainer from the SHAP library to look inside my predictions.

# we use 100 randomly sampled examples as our background dataset to integrate over
background = X_2[np.random.choice(X_2.shape[0], 100, replace=False)]

explainer = shap.DeepExplainer(get_model(), background)

When the code reaches the explainer declaration, the error below is thrown.

Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode. See PR #1483 for discussion.
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-113-d24b2d1e3b91> in <module>()
----> 1 explainer = shap.DeepExplainer(get_model(), background)

1 frames
/usr/local/lib/python3.7/dist-packages/shap/explainers/_deep/deep_tf.py in __init__(self, model, data, session, learning_phase_flags)
    100         self.model_output = _get_model_output(model)
    101         assert type(self.model_output) != list, "The model output to be explained must be a single tensor!"
--> 102         assert len(self.model_output.shape) < 3, "The model output must be a vector or a single value!"
    103         self.multi_output = True
    104         if len(self.model_output.shape) == 1:

AssertionError: The model output must be a vector or a single value!

My questions are:

  1. How can I flatten the output of my model from within the get_model function?
  2. Is there a better approach to explaining my predictions with SHAP?

Let me know if I need to share any extra information on this.

Adding a Flatten layer after a Dense layer is causing the error. Observe that the line which raises the error is:

assert len(self.model_output.shape) < 3, "The model output must be a vector or a single value!"        

Considering a 2D input, the output of a Dense layer is (None, units). So, if we have a Dense(32) layer and the batch size is set to 16, the output of that layer will be a tensor of shape (16, 32). The Flatten layer preserves the 0th axis (i.e. the batch dimension), and hence a tensor of shape (16, 32) cannot be flattened any further.

On the other hand, if you had a tensor of shape (16, 32, 3) (for example, the output of a Conv1D layer with 3 filters), then the output of the Flatten layer would be a tensor of shape (16, 96).
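
The shape rules described above can be checked with plain Python. This is only a minimal sketch: `dense_output_shape` and `flatten_output_shape` are hypothetical helpers that mimic how Keras derives output shapes (they are not Keras functions), and the 67-feature input is just an example value.

```python
from math import prod

def dense_output_shape(input_shape, units):
    """Dense acts on the last axis only: (..., features) -> (..., units)."""
    return tuple(input_shape[:-1]) + (units,)

def flatten_output_shape(input_shape):
    """Flatten keeps axis 0 (the batch) and merges every remaining axis."""
    batch, *rest = input_shape
    return (batch, prod(rest))

# Dense(32) applied to a batch of 16 samples with 67 features each
dense_out = dense_output_shape((16, 67), units=32)
print(dense_out)                            # (16, 32)

# Already (batch, features): Flatten leaves the shape unchanged
print(flatten_output_shape(dense_out))      # (16, 32)

# A rank-3 tensor is collapsed to (batch, 32 * 3)
print(flatten_output_shape((16, 32, 3)))    # (16, 96)
```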

Since you have 2D input, just remove the Flatten layer. If you were trying to reshape the output, use a Reshape layer instead.
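
With 2D data and no Flatten layer in the model, the same assertion can still fire if the shape of the whole array is passed to Input, as in `Input(shape=X_1.shape)` in the question, because Keras adds the batch axis on top of whatever shape it is given. The sketch below assumes that X_1 has the same 2D shape as the encoded X, which the question does not state. The helper functions are hypothetical and only mimic Keras shape rules and the rank check quoted above.

```python
def keras_input_shape(shape_arg):
    """Input(shape=...) excludes the batch axis; Keras prepends it as None."""
    return (None,) + tuple(shape_arg)

def dense_stack_output_shape(input_shape, units_per_layer):
    """Each Dense layer replaces only the last axis with its unit count."""
    shape = tuple(input_shape)
    for units in units_per_layer:
        shape = shape[:-1] + (units,)
    return shape

def passes_rank_check(output_shape):
    """Mirrors the quoted assertion: len(model_output.shape) < 3."""
    return len(output_shape) < 3

encoded_shape = (107012, 47726)   # (n_samples, n_features), printed in the question
layers = [128, 128, 128, 1]       # Dense unit counts from get_model()

# Whole array shape passed to Input: the sample count becomes a data axis
whole = dense_stack_output_shape(keras_input_shape(encoded_shape), layers)
print(whole, passes_rank_check(whole))            # (None, 107012, 1) False

# Only the per-sample shape passed to Input
per_sample = dense_stack_output_shape(keras_input_shape(encoded_shape[1:]), layers)
print(per_sample, passes_rank_check(per_sample))  # (None, 1) True
```

If that assumption holds, declaring the input as `Input(shape=(X_1.shape[1],))` would give the model a `(None, 1)` output, which satisfies the check.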

