
Any way to efficiently stack/ensemble pre-trained models for image classification?

I am trying to stack a few pre-trained models by taking the last hidden layer of each model, concatenating these representations, and feeding the result into a meta-learner (e.g. XGBoost).

The big problem is that every image in my dataset has to be processed multiple times, because each base model requires a different preprocessing method. This makes training take so long that it is infeasible. Is there any way around this?

For example:

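import torch

# pretrained_model*/pretrained_processor* are placeholders for the real base models
model_1, processor_1 = pretrained_model(), pretrained_processor()
model_2, processor_2 = pretrained_model2(), pretrained_processor2()

for img in images:
    # each base model needs its own preprocessing of the same image
    input_1 = processor_1(img)
    input_2 = processor_2(img)

    out_1 = model_1(input_1)
    out_2 = model_2(input_2)

    # concatenate the hidden representations to feed into another model
    features = torch.cat((out_1, out_2), dim=1)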
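A minimal, self-contained sketch of the stacking setup described in the question, assuming the base models stay frozen. Tiny `nn.Sequential` networks stand in for the pre-trained backbones, two small functions stand in for their different processors, and the data is random; all of these are hypothetical. Because frozen backbones return the same features every time, each image only has to go through each processor and model once, and the meta-learner is then trained on the cached feature matrix. A linear classifier trained with PyTorch replaces XGBoost here only to keep the example dependency-free; with XGBoost you would fit on `features.numpy()` instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for two pre-trained backbones: each maps an image
# to a "last hidden layer" vector (16 and 32 features; sizes are arbitrary).
base_1 = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.ReLU())
base_2 = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())


def processor_1(x):
    return (x - 0.5) / 0.5  # hypothetical normalization for base_1


def processor_2(x):
    return x * 2.0  # hypothetical normalization for base_2


# Synthetic dataset: 100 RGB 8x8 images with binary labels.
images = torch.rand(100, 3, 8, 8)
labels = (images.mean(dim=(1, 2, 3)) > 0.5).long()

# Step 1: run every image through each processor/backbone pair exactly once.
base_1.eval()
base_2.eval()
with torch.no_grad():
    features = torch.cat(
        (base_1(processor_1(images)), base_2(processor_2(images))), dim=1
    )  # shape (100, 48)

# Step 2: train the meta-learner on the cached features only; the backbones
# are never called again, however many epochs this loop runs.
meta = nn.Linear(features.shape[1], 2)
optimizer = torch.optim.Adam(meta.parameters(), lr=0.05)
loss_fn = nn.CrossEntropyLoss()
losses = []
for epoch in range(50):
    optimizer.zero_grad()
    loss = loss_fn(meta(features), labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```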

Here's a recommendation if you want to process your images faster:

Note: I have not tested this.

import torch
import torch.nn as nn

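# Wrap both base models in a single nn.Module
class StackedModel(nn.Module):
  def __init__(self, model1, model2):
    super().__init__()

    self.model1 = model1
    self.model2 = model2

  def forward(self, input_1, input_2):
    # Each base model gets its own preprocessed batch
    out_1 = self.model1(input_1)
    out_2 = self.model2(input_2)

    # Concatenate the hidden representations along the feature dimension
    return torch.cat((out_1, out_2), dim=1)

# Init model (model_1 and model_2 are the pre-trained base models from the question)
model = StackedModel(model_1, model_2)
model.eval()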

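# Preprocess each image once per processor and run the models on larger batches,
# assuming you have spare GPU memory. This assumes each processor returns a single
# tensor per image and that all tensors from the same processor have the same shape.
stacked_preproc1 = []
stacked_preproc2 = []
max_batch_size = 16
total_output = []

with torch.no_grad():
  for index, img in enumerate(images):
    stacked_preproc1.append(processor_1(img))
    stacked_preproc2.append(processor_2(img))

    # Run the model once the batch is full, or on whatever is left after the last image
    if len(stacked_preproc1) == max_batch_size or index == len(images) - 1:
      total_output.append(
          model(torch.stack(stacked_preproc1), torch.stack(stacked_preproc2))
      )

      # Reset the lists for the next batch
      stacked_preproc1 = []
      stacked_preproc2 = []

# One row of concatenated features per image, ready for the meta-learner
features = torch.cat(total_output, dim=0)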
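The `StackedModel` above is hard-wired to two base models. A possible generalization, sketched here under the same assumption (each processor returns one tensor of identical shape per image), keeps the base models in an `nn.ModuleList` and factors the batching loop into a function. `StackedModelN`, `extract_stacked_features`, `processor_a` and `processor_b` are hypothetical names, and the toy models only stand in for real pre-trained backbones. The usage at the end feeds 37 images so that the last batch is partial.

```python
import torch
import torch.nn as nn


class StackedModelN(nn.Module):
    """Hypothetical generalization of StackedModel to any number of base models."""

    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)

    def forward(self, *inputs):
        # inputs[i] is the preprocessed batch for self.models[i]
        outs = [model(x) for model, x in zip(self.models, inputs)]
        return torch.cat(outs, dim=1)


def extract_stacked_features(model, processors, images, batch_size=16):
    """Preprocess each image once per processor and run the models in batches."""
    model.eval()
    buffers = [[] for _ in processors]  # one list of preprocessed tensors per model
    outputs = []

    def flush():
        batch = [torch.stack(buf) for buf in buffers]
        outputs.append(model(*batch))
        for buf in buffers:
            buf.clear()

    with torch.no_grad():
        for img in images:
            for buf, processor in zip(buffers, processors):
                buf.append(processor(img))
            if len(buffers[0]) == batch_size:
                flush()
        if buffers[0]:  # leftover partial batch after the last image
            flush()
    return torch.cat(outputs, dim=0)


# Usage with toy stand-ins (assumptions, not real pre-trained models)
torch.manual_seed(0)
models = [
    nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4)),
    nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 6)),
]


def processor_a(img):
    return (img - 0.5) / 0.5  # hypothetical normalization for the first model


def processor_b(img):
    return img * 2.0  # hypothetical normalization for the second model


images = [torch.rand(3, 8, 8) for _ in range(37)]  # 37 = 2 full batches + 5
stacked = StackedModelN(models)
features = extract_stacked_features(stacked, [processor_a, processor_b], images)
print(features.shape)  # torch.Size([37, 10])
```

The time saving comes from the batching: each backbone does one forward pass per batch instead of one per image, while each image is still preprocessed only once per processor.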
      
