[英]insert dtype in std::map
我想做一個映射,它需要一對pybind11::dtype
和int
並將其映射為 OpenCV 格式:
static std::map<std::pair<pybind11::dtype, int>, int> ocv_types;
所以我insert
了所有組合,但是在添加int32_t
和float_t
時似乎存在問題:
ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::int32_t>() , 3), CV_32SC3));
ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::float_t>() , 3), CV_32FC3));
當我這樣做時,只有CV_32SC3
是真正的insert
ed ,我猜在某個地方程序“認為”兩個元素相等,因此不會插入第二個元素。
我該如何添加這兩個?
PS我做這個檢查只是為了“證明”類型不相等:
if(pybind11::dtype::of<std::int32_t>() == pybind11::dtype::of<std::float_t>())
{
std::cout << "std::int32_t == std::float_t" << std::endl;
}
else
{
std::cout << "std::int32_t != std::float_t" << std::endl;
}
...當然他們不是。
編輯
我為dtype
添加了<
函數並將其用於地圖的比較函數中,但並非所有元素都存在於地圖中:
int getVal(pybind11::dtype type)
{
if(type.is(pybind11::dtype::of<std::uint8_t>()))
return 1;
if(type.is(pybind11::dtype::of<std::uint16_t>()))
return 2;
if(type.is(pybind11::dtype::of<std::int16_t>()))
return 3;
if(type.is(pybind11::dtype::of<std::int32_t>()))
return 4;
if(type.is(pybind11::dtype::of<std::float_t>()))
return 5;
if(type.is(pybind11::dtype::of<std::double_t>()))
return 6;
}
inline bool operator <(const pybind11::dtype a, const pybind11::dtype b) //friend claim has to be here
{
return getVal(a) < getVal(b);
}
auto comp = [](const std::pair<pybind11::dtype, int> a, const std::pair<pybind11::dtype, int> b)
{
return a < b;
};
static std::map<std::pair<pybind11::dtype, int>, int, decltype(comp)> ocv_types(comp);
正如您所指出pybind11::dtype
沒有任何特定的順序。 所以 IMO 最好的方法是使用std::unordered_map
並提供各自的哈希值。 pybind11
已經有一些哈希函數,所以需要將它用於std::hash
。
這是我編寫的測試(使用 Catch2),它通過了我的機器:
主.cpp:
#include "catch2/catch_all.hpp"
#include <pybind11/embed.h>
#include <pybind11/numpy.h>
#include <unordered_map>
template<>
struct std::hash<pybind11::dtype>
{
size_t operator()(const pybind11::dtype &t) const
{
return pybind11::hash(t);
}
};
template<>
struct std::hash<std::pair<pybind11::dtype, int>>
{
size_t operator()(const std::pair<pybind11::dtype, int> &t) const
{
return std::hash<pybind11::dtype>{}(t.first) ^ static_cast<size_t>(t.second);
}
};
TEST_CASE("map_with_dtype") {
constexpr auto CV_32SC3 = 1;
constexpr auto CV_32FC3 = 2;
pybind11::scoped_interpreter guard{};
std::unordered_map<std::pair<pybind11::dtype, int>, int> ocv_types;
REQUIRE(ocv_types.empty());
auto a = ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::int32_t>() , 3), CV_32SC3));
REQUIRE(a.second);
auto b = ocv_types.insert(std::make_pair(std::make_pair(pybind11::dtype::of<std::float_t>() , 3), CV_32FC3));
REQUIRE(b.second);
CHECK(b.first->second == CV_32FC3);
CHECK(ocv_types.size() == 2);
}
CMakeLists.txt:
cmake_minimum_required(VERSION 3.16)
# set the project name
project(MapOfPyBind11)
find_package(Catch2 REQUIRED)
find_package(pybind11 REQUIRED)
# add the executable
add_executable(MapOfPyBind11Test main.cpp)
target_link_libraries(MapOfPyBind11Test PRIVATE Catch2::Catch2 pybind11::module pybind11::embed)
include(CTest)
include(Catch)
catch_discover_tests(MapOfPyBind11Test)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.