
How to get the output of a CNN with the same dimensions as the input

I have grayscale images whose pixel arrays are stored in x_train and x_test.

x_train has shape (2500, 21, 512) and x_test has shape (500, 21, 512). I want to train a CNN whose outputs have the same shapes: y_train of shape (2500, 21, 512) and y_test of shape (500, 21, 512). These targets are the pixel arrays of other images that I want the network to predict.

The MNIST examples do something similar, but there y_train and y_test are vectors of label values and the output has shape (3000, 1). How could I do the same for my images?

Hmm, I don't fully understand your question, but I'll take a stab at it. Please let me know if I have misinterpreted it.

Your model takes the following input:

x_train: the image.

And outputs:

x_hat: an image with the same dimensions as `x_train`.

Judging by the described architecture, it seems like you are building a convolutional autoencoder. Am I correct?

If so, you have to do the following:

  1. You need to add a channel dimension of size one so that the CNN can receive the input, which can be done by reshaping the tensor. Convolutional layers expect 4-D inputs: (batch_size, channels, height, width) in PyTorch, or (batch_size, height, width, channels) by default in TensorFlow/Keras. For your data, that means reshaping x_train from (2500, 21, 512) to (2500, 1, 21, 512) or (2500, 21, 512, 1). If you don't want to add a channel, you can use a simple feed-forward neural network (or MLP) instead. In that case you still have to flatten the inputs to the shape (batch_size, pixels). As a concrete example, with the MNIST dataset and a batch_size of 32, the input shape is (32, 784), since MNIST images are 28 x 28 and flattening one gives 784 values.

  2. You can create a convolutional autoencoder by using strided convolutions to downsample the images in the encoder layers. The decoder then takes the intermediate representation and upsamples it back to the input size with transposed convolutions. If you want to train a model that can generate new samples rather than reconstruct its input, I recommend looking up variational autoencoders and generative adversarial networks.
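One thing to watch with a 21 x 512 image is that 21 is odd, so a stride-2 convolution followed by a stride-2 transposed convolution does not automatically return to the original size. Here is a rough sketch of the shape arithmetic in plain Python. The kernel size 3, stride 2, and padding 1 are my own assumptions, not something from your question, and the formulas are the standard ones for dilation 1.

```python
def conv_out(n, kernel=3, stride=2, padding=1):
    """Output length along one axis after a strided convolution."""
    return (n + 2 * padding - kernel) // stride + 1


def tconv_out(n, kernel=3, stride=2, padding=1, output_padding=0):
    """Output length along one axis after a transposed convolution."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding


height, width = 21, 512  # image size from the question

# Encoder: two stride-2 convolutions (assumed depth).
h1, w1 = conv_out(height), conv_out(width)
h2, w2 = conv_out(h1), conv_out(w1)

# Decoder: two stride-2 transposed convolutions.
# The odd height needs output_padding=0, the even width needs output_padding=1.
h3 = tconv_out(h2, output_padding=0)
w3 = tconv_out(w2, output_padding=1)
h4 = tconv_out(h3, output_padding=0)
w4 = tconv_out(w3, output_padding=1)

print("encoder:", (h1, w1), (h2, w2))
print("decoder:", (h3, w3), (h4, w4))
```

With these assumed settings the decoder ends at (21, 512) again, but only because the output padding differs between the two axes. A different kernel, stride, or depth would need the same check.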

The implementation will vary depending on the framework (e.g., PyTorch, TensorFlow).
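To make this concrete, here is a minimal PyTorch sketch of the idea, not a tuned model. The class name `ConvAutoencoder`, the layer widths (16 and 32 channels), the kernel/stride/padding values, and the random stand-in tensors are all assumptions of mine. The `output_padding=(0, 1)` setting is what brings the odd height (21) and the even width (512) back to their original sizes with these particular layers.

```python
import torch
from torch import nn


class ConvAutoencoder(nn.Module):
    """Hypothetical encoder/decoder for single-channel 21 x 512 images."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # (N, 1, 21, 512) -> (N, 16, 11, 256)
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # (N, 16, 11, 256) -> (N, 32, 6, 128)
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            # (N, 32, 6, 128) -> (N, 16, 11, 256)
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1,
                               output_padding=(0, 1)),
            nn.ReLU(),
            # (N, 16, 11, 256) -> (N, 1, 21, 512)
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1,
                               output_padding=(0, 1)),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


# Random stand-ins for one batch of x_train and y_train (not real data).
x_batch = torch.randn(8, 21, 512)
y_batch = torch.randn(8, 21, 512)

model = ConvAutoencoder()
inputs = x_batch.unsqueeze(1)        # add the channel axis: (8, 1, 21, 512)
y_hat = model(inputs).squeeze(1)     # drop it again: (8, 21, 512)

# Pixel-wise regression loss against the target images.
loss = nn.functional.mse_loss(y_hat, y_batch)
print(y_hat.shape, loss.item())
```

Since your targets are other images rather than the inputs themselves, the loss compares the output with y_train instead of x_train. The rest of the training loop (optimizer, epochs, batching) is the usual one and is left out here.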


 