繁体   English   中英

PyTorch onnx 中的规范化 model

[英]PyTorch normalization in onnx model

我在 pytorch 做图像分类,在那,我用了这个变换

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

并完成了培训。 之后,我将.pth model 文件转换为.onnx 文件

现在,在推论中,我应该如何在 numpy 数组中应用此转换,因为 onnx 处理 numpy 数组中的输入

例如,您可以将相同的transforms应用于np.array

你有几个选择。

由于 normalize 自己编写非常简单,您可以这样做

import numpy as np
mean = np.array([0.485, 0.456, 0.406]).reshape(-1,1,1)
std = np.array([0.229, 0.224, 0.225]).reshape(-1,1,1)
x_normalized = (x - mean) / std

根本不需要 pytorch 或 torchvision 库。

如果您仍在使用 pytorch 数据集,则可以使用以下转换

transforms.Compose([
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    torch.Tensor.numpy  # or equivalently transforms.Lambda(lambda x: x.numpy())
])

这只会将归一化应用于张量,然后将其转换为 numpy 数组。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM