[英]How to concatenate LibTorch tensors created with a multi-thread process std::thread in C++?
C++ 中的多線程進程返回張量,我想將它們按順序連接成一個張量。
在 C++ 我有一個 function 返回一個 1x8 張量。
我用std::thread
同時多次調用這個 function 並且我想將它返回的張量連接成一個大張量。 例如,我稱它為 12 次,我希望它完成后會有一個 12x8 的張量。
我需要將它們按順序連接起來,也就是說,用 0 調用的張量應該總是第 0 個 position 中的 go ,然后是第 1 個 Z4757FE07FD492A8BE0EA6A760 中的第一個 Z4757FE07FD492A8BE0EAZ,以此類推。
我知道我可以讓 function 返回一個 12x8 張量,但我需要解決如何獲取多線程過程中產生的張量的問題。
在下面的嘗試中,我嘗試將張量連接到all_episode_steps
張量中,但這會返回錯誤。
如果您注釋掉all_episode_steps
行並將std::cout << one;
在返回語句上方的get_tensors
function 中,您會看到它似乎使用多線程來毫無問題地創建張量。
#include <torch/torch.h>
torch::Tensor get_tensors(int id) {
torch::Tensor one = torch::rand({8});
return one.unsqueeze(0);
}
torch::Tensor all_episode_steps;
int main() {
std::thread ths[100];
for (int id=0; id<12; id++) {
ths[id] = std::thread(get_tensors, id);
all_episode_steps = torch::cat({ths[id], all_episode_steps});
}
for (int id=0; id<12; id++) {
ths[id].join();
}
}
如果您想自己構建它,您可以在此處安裝 LibTorch 。
下面是上面代碼的 CMakeLists.txt 文件。
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
if (MSVC)
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
add_custom_command(TARGET example-app
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${TORCH_DLLS}
$<TARGET_FILE_DIR:example-app>)
endif (MSVC)
線程不能返回張量,但可以通過指針修改張量。 試試這個(未經測試,可能需要一些調整):
void get_tensors(torch::Tensor* out) {
torch::Tensor one = torch::rand({8});
*out = one.unsqueeze(0);
}
int main() {
std::thread ths[12];
std::vector<torch::Tensor> results(12);
for (int id=0; id<12; id++) {
ths[id] = std::thread(get_tensors, &results[id]);
}
for (int id=0; id<12; id++) {
ths[id].join();
}
auto result2d = torch::cat(results);
}
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.