简体   繁体   English

.h5 keras model 到 coreml 的分类转换在 IOS 中不起作用

[英].h5 keras model to coreml conversion for classification does not work in IOS

I trained a CNN classification model using RGB images as input and it produces 1x7 output with probabilities of class labels(7 different classes).我使用 RGB 图像作为输入训练了一个 CNN 分类 model,它产生 1x7 output,概率为 class 标签(7 个不同的类别)。 I have converted the model from keras.h5 to coreML.我已将 model 从 keras.h5 转换为 coreML。 I have seen different applications and tried both of them with and without class labels defined.我见过不同的应用程序,并在定义和不定义 class 标签的情况下尝试了它们。 They did not cause any issue while converting.他们在转换时没有引起任何问题。 However none of them work in IOS.但是它们都不能在 IOS 中工作。 Both models crash when I call below line:当我调用以下行时,两种模型都会崩溃:

 guard let result = predictionRequest.results as? [VNCoreMLFeatureValueObservation] else {
        fatalError("model failed to process image")
    }

Output definition of my both models are below.我的两种型号的 Output 定义如下。 Could you please advice what is wrong with the model output.您能否告知 model output 有什么问题。 Do I have to add class labels or not?我是否必须添加 class 标签? I am confused how to call the highest probable value.我很困惑如何调用最高可能值。 I have added entire classification code too.我也添加了整个分类代码。 Please see below.请看下文。 Since I am a beginner in IOS, your help is greatly appreciated.由于我是 IOS 的初学者,非常感谢您的帮助。 Thanks a lot indeed.确实非常感谢。

Model output definition in IOS with class labels conversion: Model output 定义在 IOS 与 class 标签转换:

/// Identity as dictionary of strings to doubles
lazy var Identity: [String : Double] = {
    [unowned self] in return self.provider.featureValue(for: "Identity")!.dictionaryValue as! [String : Double]
}()


/// classLabel as string value
lazy var classLabel: String = {
    [unowned self] in return self.provider.featureValue(for: "classLabel")!.stringValue
}()

Model output definition in IOS without class labels conversion: Model output 定义在 IOS 没有 class 标签转换:

init(Identity: MLMultiArray) {
    self.provider = try! MLDictionaryFeatureProvider(dictionary: ["Identity" : MLFeatureValue(multiArray: Identity)])
}

Classification Code:分类代码:

class ColorStyleVisionManager: NSObject {
static let shared = ColorStyleVisionManager()
static let MODEL = hair_color_class_labels().model
var colorStyle = String()
var hairColorFlag: Int = 0
private lazy var predictionRequest: VNCoreMLRequest = {
    do{
        let model = try VNCoreMLModel(for: ColorStyleVisionManager.MODEL)
       
        let request = VNCoreMLRequest(model: model)
        request.imageCropAndScaleOption = VNImageCropAndScaleOption.centerCrop
        return request
    } catch {
        fatalError("can't load Vision ML Model")
    }
}()


func predict(image:CIImage) -> String {
    

    guard let result = predictionRequest.results as? [VNCoreMLFeatureValueObservation] else {
        fatalError("model failed to process image")
    }
    
    let firstResult = result.first


    if firstResult?.featureName == "0" {
            colorStyle = "Plain Coloring"
            hairColorFlag = 1
        }
        else if firstResult?.featureName == "1" {
            colorStyle = "Ombre"
            hairColorFlag = 2
        }
        else if firstResult?.featureName == "2" {
            colorStyle = "Sombre"
            hairColorFlag = 2
        }
        else if firstResult?.featureName == "3" {
            colorStyle = "HighLight"
            hairColorFlag = 3
        }
        else if firstResult?.featureName == "4" {
            colorStyle = "LowLight"
            hairColorFlag = 3
        }
        else if firstResult?.featureName == "5" {
            colorStyle = "Color Melt"
            hairColorFlag = 5
        }
        else if firstResult?.featureName == "6" {
            colorStyle = "Dip Dye"
            hairColorFlag = 4
        }

    else {}

    let handler = VNImageRequestHandler(ciImage: image)


    do {
            try handler.perform([predictionRequest])
        } catch {
            print("error handler")
        }

    
    return colorStyle
}

} }

I have found out two different problems in my code.我在我的代码中发现了两个不同的问题。 In order to ensure that my model correctly converted to mlmodel, I created a new classification mlmodel by using Apple's CreateML tool.为了确保我的 model 正确转换为 mlmodel,我使用 Apple 的 CreateML 工具创建了一个新的分类 mlmodel。 By the way it is fantastic even though the accuracy seems lower than my original model.顺便说一句,即使准确度似乎低于我原来的 model,它也很棒。 I compared the output and input types of the model and seems my mlmodel is correct too.我比较了 output 和 model 的输入类型,似乎我的 mlmodel 也是正确的。 Then I used this model and gave it another try.然后我使用了这个 model 并再次尝试。 It crashed again.它又崩溃了。 I wasn't so sure what prediction result I have to expect whether "VNClassificationObservation" or "VNCoreMLFeatureValueObservation".我不太确定我必须期待“VNClassificationObservation”还是“VNCoreMLFeatureValueObservation”的预测结果。 I changed to classificationobservation.我改为分类观察。 It crashed again.它又崩溃了。 Then I realized that my handler definition was below the crash line and I moved it to upper portion.然后我意识到我的处理程序定义在崩溃线下方,我将它移到了上部。 Then woola.然后是羊毛。 It worked.有效。 I double checked by changing the FeatureValueObservation and it crashed again.我通过更改 FeatureValueObservation 进行了仔细检查,它再次崩溃了。 So two problems are solved.这样两个问题就解决了。 Please see the correct code below.请参阅下面的正确代码。

I strongly recommend to use CreateML tool to confirm your model conversion work fine for debugging purposes.我强烈建议使用 CreateML 工具来确认您的 model 转换工作正常以用于调试目的。 It is just a few minutes job.这只是几分钟的工作。

class ColorStyleVisionManager: NSObject {
static let shared = ColorStyleVisionManager()
static let MODEL = hair_color_class_labels().model
var colorStyle = String()
var hairColorFlag: Int = 0
private lazy var predictionRequest: VNCoreMLRequest = {
    do{
        let model = try VNCoreMLModel(for: ColorStyleVisionManager.MODEL)
       
        let request = VNCoreMLRequest(model: model)
        request.imageCropAndScaleOption = VNImageCropAndScaleOption.centerCrop
        return request
    } catch {
        fatalError("can't load Vision ML Model")
    }
}()


func predict(image:CIImage) -> String {
    
    let handler = VNImageRequestHandler(ciImage: image)


    do {
            try handler.perform([predictionRequest])
        } catch {
            print("error handler")
        }

    guard let result = predictionRequest.results as? [VNClassificationObservation] else {
        fatalError("error to process request")
    }
    
    let firstResult = result.first
    print(firstResult!)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM