
Keras - Artificial Neural Networks - Error when using a custom activation function

I'm creating an Artificial Neural Network (ANN) using Keras's Functional API. Link to the data CSV file: https://github.com/dpintof/SPX_Options_ANN/blob/master/MLP3/call_df.csv . Relevant part of the code that reproduces the problem:

import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import layers


# Data 
call_df = pd.read_csv("call_df.csv")

call_X_train, call_X_test, call_y_train, call_y_test = train_test_split(
    call_df.drop(["Option_Average_Price"], axis = 1),
    call_df.Option_Average_Price, test_size = 0.01)


# Hyperparameters
n_hidden_layers = 2 # Number of hidden layers.
n_units = 128 # Number of neurons of the hidden layers.

# Create input layer
inputs = keras.Input(shape = (call_X_train.shape[1],))
x = layers.LeakyReLU(alpha = 1)(inputs)

"""
Function that creates a hidden layer by taking a tensor as input and applying a
modified ELU (MELU) activation function.
"""
def hl(tensor):
    # Create custom MELU activation function
    def melu(z):
        return tf.cond(z > 0, lambda: ((z**2)/2 + 0.02*z) / (z - 2 + 1/0.49), 
                        lambda: 0.49*(keras.activations.exponential(z)-1))
    
    y = layers.Dense(n_units, activation = melu)(tensor)
    return y

# Create hidden layers
for _ in range(n_hidden_layers):
    x = hl(x)

# Create output layer
outputs = layers.Dense(1, activation = keras.activations.softplus)(x)

# Actually create the model
model = keras.Model(inputs=inputs, outputs=outputs)


# QUICK TEST
model.compile(loss = "mse", optimizer = keras.optimizers.Adam())
history = model.fit(call_X_train, call_y_train, 
                    batch_size = 4096, epochs = 1,
                    validation_split = 0.01, verbose = 1)

This is the error I get when I run model.fit(…) (note that 4096 is my batch size and 128 is the number of neurons in the hidden layers):

InvalidArgumentError:  The second input must be a scalar, but it has shape [4096,128]
     [[{{node dense/cond/dense/BiasAdd/_5}}]] [Op:__inference_keras_scratch_graph_1074]

Function call stack:
keras_scratch_graph

I know the problem has to do with the custom activation function, because the program runs fine if I use the following hl function instead:

def hl(tensor):
    lr = layers.Dense(n_units, activation = layers.LeakyReLU())(tensor)
    return lr

I got the same error when trying to define melu(z) like this:

@tf.function
def melu(z):
    if z > 0:
        return ((z**2)/2 + 0.02*z) / (z - 2 + 1/0.49)
    else:
        return 0.49*(keras.activations.exponential(z)-1)

From "How do you create a custom activation function with Keras?" I also tried the following, but without success:

def hl(tensor):
    # Create custom MELU activation function
    def melu(z):
        return tf.cond(z > 0, lambda: ((z**2)/2 + 0.02*z) / (z - 2 + 1/0.49), 
                        lambda: 0.49*(keras.activations.exponential(z)-1))
    
    from keras.utils.generic_utils import get_custom_objects
    get_custom_objects().update({'melu': layers.Activation(melu)})
 
    x = layers.Dense(n_units)(tensor)
    y = layers.Activation(melu)(x)
    return y

This issue happens because tf.cond expects a scalar for its condition argument, not a multi-dimensional tensor. (The @tf.function attempt fails the same way, because AutoGraph converts a Python if on a tensor into a tf.cond.) Instead, you can use tf.where to apply the conditional element-wise.
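The constraint is easy to reproduce in isolation. Here is a minimal sketch (not from the original post) that triggers the same error with the shapes from the question:

import tensorflow as tf

@tf.function
def cond_demo(z):
    # tf.cond picks one branch for the entire call, so its predicate must be
    # a single scalar boolean; here `z > 0` has shape [4096, 128].
    return tf.cond(z > 0, lambda: z, lambda: -z)

# Raises InvalidArgumentError: The second input must be a scalar,
# but it has shape [4096,128]
cond_demo(tf.ones([4096, 128]))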

For example, you can define melu as follows:

def melu(z):
    return tf.where(z > 0, ((z**2)/2 + 0.02*z) / (z - 2 + 1/0.49), 
                           0.49*(keras.activations.exponential(z)-1))

NOTE: Not tested.
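For completeness, here is an untested sketch of how the tf.where version could slot back into the hl helper from the question (assuming the same n_units and imports as above):

def hl(tensor):
    def melu(z):
        # tf.where evaluates the condition element-wise over the whole
        # [batch, units] tensor, so no scalar predicate is needed. Note that
        # it computes both branch expressions for every element before
        # selecting between them.
        return tf.where(z > 0, ((z**2)/2 + 0.02*z) / (z - 2 + 1/0.49),
                               0.49*(keras.activations.exponential(z) - 1))
    return layers.Dense(n_units, activation = melu)(tensor)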
