std::function has performance issues, how can I avoid it?
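I have some classes that allow composing covariance functions (also called kernels, see https://stats.stackexchange.com/questions/228552/covariance-functions-or-kernels-what-exactly-are-they) and then computing the covariance for the resulting new kernel, for example:

```cpp
auto C = GaussianKernel(50,60) + GaussianKernel(100,200);
auto result = C.covarianceFunction(30.0,40.0);
```

The problem is that every time I compute a covariance I go through a `std::function` call. Is there a simple way to avoid it?
Note that I want to compute a large covariance matrix (about 50K × 50K), so performance matters.
Here is the code: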
```cpp
#include <cmath>
#include <complex>
#include <functional>

class Kernel {
public:
    /*
    Covariance function : return the covariance between two R.V. for the entire kernel's domain definition.
    */
    virtual double covarianceFunction(
        double X,
        double Y
    )const = 0 ;
    virtual ~Kernel() = default;
};
class FooKernel : public Kernel {
public:
    FooKernel(std::function<double(double, double)> fun) : fun_(fun) {}
    double covarianceFunction(
        double X,
        double Y
    ) const override {
        return fun_(X, Y);
    }
    template<class T>
    auto operator+(const T b) const {
        // Capture a copy of *this rather than `this`: the left operand is
        // usually a temporary destroyed at the end of the expression.
        return FooKernel([a = *this, b](double X, double Y) -> double {
            return a.covarianceFunction(X, Y) + b.covarianceFunction(X, Y);
        });
    }
private:
    std::function<double(double, double)> fun_;
};
class GaussianKernel : public Kernel {
public:
    GaussianKernel(double sigma, double scale) : m_sigma(sigma), m_scale(scale) {}
    GaussianKernel(double sigma) : m_sigma(sigma), m_scale(1) {}
    /*
    A well known covariance function that enforces smooth deformations
    Ref : Shape modeling using Gaussian process Morphable Models, Luethi et al.
    */
    double covarianceFunction(
        double X,
        double Y
    ) const override
    {
        //use diagonal matrix
        double result;
        // std::norm of a real number is its square, i.e. (X - Y)^2
        result = m_scale * std::exp(-std::norm(X - Y) / (m_sigma*m_sigma));
        return result;
    }
    template<class T>
    auto operator+(const T b) const {
        // Copy *this into the closure for the same lifetime reason as above.
        return FooKernel([a = *this, b](double X, double Y) -> double {
            auto debugBval = b.covarianceFunction(X, Y);
            auto debugAval = a.covarianceFunction(X, Y);
            auto test = debugBval + debugAval;
            return test;
        });
    }
private:
    double m_sigma;
    double m_scale;
};
```
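To make the performance concern concrete, here is a minimal sketch of the loop that fills the covariance matrix. `fillCovariance`, the row-major `std::vector<double>` storage and the symmetry shortcut are assumptions for illustration, not part of the original code. Because the kernel's type is a template parameter, each `k(x[i], x[j])` is a direct call the compiler can inline, whereas a `std::function` call is an indirect call through type erasure, repeated here on the order of 10^9 times for 50K points.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the original post): fills the n x n
// covariance matrix K(i, j) = k(x[i], x[j]) in row-major order.
// KernelFn is a template parameter, so k(x[i], x[j]) is a direct call the
// compiler can inline, unlike a call through std::function.
// Assumes a symmetric kernel (k(a, b) == k(b, a)), so only the upper
// triangle is evaluated and then mirrored.
template <typename KernelFn>
std::vector<double> fillCovariance(const std::vector<double>& x, const KernelFn& k)
{
    const std::size_t n = x.size();
    std::vector<double> K(n * n);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = i; j < n; ++j) {
            const double v = k(x[i], x[j]);
            K[i * n + j] = v;  // upper triangle
            K[j * n + i] = v;  // mirrored lower triangle
        }
    }
    return K;
}
```

A kernel object from the question can be passed through a small lambda, e.g. `fillCovariance(x, [&C](double a, double b) { return C.covarianceFunction(a, b); })`. Note that a dense 50K × 50K matrix of doubles needs about 20 GB of memory on its own.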
By making `FooKernel` a template, you can avoid using `std::function`.
```cpp
#include <cmath>
#include <complex>
#include <iostream>
#include <utility>

class Kernel {
public:
    /*
    Covariance function : return the covariance between two R.V. for the entire kernel's domain definition.
    */
    virtual double covarianceFunction(
        double X,
        double Y
    )const = 0 ;
    virtual ~Kernel() = default;
};

template <typename Func> class FooKernel;

// Declared before FooKernel's body so that FooKernel::operator+ can call it.
template <typename Func>
auto make_foo_kernel(Func&& fun) -> FooKernel<Func>;

template <typename Func>
class FooKernel : public Kernel {
public:
    FooKernel(Func&& fun) : fun_(std::forward<Func>(fun)) {}
    double covarianceFunction(
        double X,
        double Y
    ) const override {
        return fun_(X, Y);
    }
    template<class T>
    auto operator+(const T b) const {
        // Copy *this into the closure; capturing `this` would dangle when the
        // left operand is a temporary.
        return make_foo_kernel([a = *this, b](double X, double Y) -> double {
            return a.covarianceFunction(X, Y) + b.covarianceFunction(X, Y);
        });
    }
private:
    Func fun_;
};

template <typename Func>
auto make_foo_kernel(Func&& fun) -> FooKernel<Func>
{
    return FooKernel<Func>(std::forward<Func>(fun));
}

class GaussianKernel : public Kernel {
public:
    GaussianKernel(double sigma, double scale) : m_sigma(sigma), m_scale(scale) {}
    GaussianKernel(double sigma) : m_sigma(sigma), m_scale(1) {}
    /*
    A well known covariance function that enforces smooth deformations
    Ref : Shape modeling using Gaussian process Morphable Models, Luethi et al.
    */
    double covarianceFunction(
        double X,
        double Y
    ) const override
    {
        //use diagonal matrix
        double result;
        // std::norm of a real number is its square, i.e. (X - Y)^2
        result = m_scale * std::exp(-std::norm(X - Y) / (m_sigma*m_sigma));
        return result;
    }
    template<class T>
    auto operator+(const T b) const {
        return make_foo_kernel([a = *this, b](double X, double Y) -> double {
            auto debugBval = b.covarianceFunction(X, Y);
            auto debugAval = a.covarianceFunction(X, Y);
            auto test = debugBval + debugAval;
            return test;
        });
    }
private:
    double m_sigma;
    double m_scale;
};

int main()
{
    auto C = GaussianKernel(50,60) + GaussianKernel(100,200);
    auto result = C.covarianceFunction(30.0,40.0);
    std::cout << result << '\n';
    return 0;
}
```
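The same templating idea can also be expressed without lambdas or the virtual base: each composite stores its operands by value, and its full type encodes the whole sum, so every term is visible to the optimizer. This is a sketch under assumptions: `Gaussian`, `Sum` and the `is_kernel` trait are hypothetical names, and kernels are plain function objects rather than `Kernel` subclasses.

```cpp
#include <cmath>
#include <type_traits>

// Hypothetical stand-in for GaussianKernel as a plain function object.
struct Gaussian {
    double sigma;
    double scale;
    double operator()(double X, double Y) const {
        const double d = X - Y;
        return scale * std::exp(-(d * d) / (sigma * sigma));
    }
};

// Composite kernel: both operands are stored by value, so nothing can dangle,
// and the full type (e.g. Sum<Sum<Gaussian, Gaussian>, Gaussian>) is known at
// compile time, which lets the compiler inline every term.
template <typename A, typename B>
struct Sum {
    A a;
    B b;
    double operator()(double X, double Y) const { return a(X, Y) + b(X, Y); }
};

// Opt-in trait so that operator+ only applies to kernel types.
template <typename T> struct is_kernel : std::false_type {};
template <> struct is_kernel<Gaussian> : std::true_type {};
template <typename A, typename B> struct is_kernel<Sum<A, B>> : std::true_type {};

template <typename A, typename B,
          typename = std::enable_if_t<is_kernel<A>::value && is_kernel<B>::value>>
Sum<A, B> operator+(const A& a, const B& b)
{
    return {a, b};
}
```

Usage mirrors the question: `auto C = Gaussian{50, 60} + Gaussian{100, 200}; double r = C(30.0, 40.0);`. Storing the operands by value also sidesteps the lifetime problem of capturing `this` in a lambda.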
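With this design, the only improvement over using `std::function` is to parameterize the class with a template, which may introduce other unwanted problems.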
```cpp
template<class Fun>
class FooKernel : public Kernel {
public:
    FooKernel(Fun&& fun) : fun_(std::forward<Fun>(fun)) {}
    ...
private:
    Fun fun_;
};
```
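One likely example of such problems: every lambda has its own closure type, so two kernels built by `make_foo_kernel` are unrelated types that cannot be assigned to each other or stored in the same container directly. The sketch below reuses a reduced copy of the templated `FooKernel` (the `makeKernelList` helper is hypothetical) and shows that getting a heterogeneous collection back means going through the `Kernel` base, with one virtual call per evaluation.

```cpp
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

struct Kernel {
    virtual double covarianceFunction(double X, double Y) const = 0;
    virtual ~Kernel() = default;
};

// Reduced copy of the templated FooKernel: the callable type is a parameter.
template <typename Func>
class FooKernel : public Kernel {
public:
    explicit FooKernel(Func fun) : fun_(std::move(fun)) {}
    double covarianceFunction(double X, double Y) const override { return fun_(X, Y); }
private:
    Func fun_;
};

template <typename Func>
FooKernel<Func> make_foo_kernel(Func fun) { return FooKernel<Func>(std::move(fun)); }

// Hypothetical helper: each lambda yields a distinct FooKernel<...> type, so
// storing several kernels together requires the common Kernel base again.
std::vector<std::unique_ptr<Kernel>> makeKernelList()
{
    auto product = make_foo_kernel([](double X, double Y) { return X * Y; });
    auto sum = make_foo_kernel([](double X, double Y) { return X + Y; });
    static_assert(!std::is_same<decltype(product), decltype(sum)>::value,
                  "different lambdas give different FooKernel types");

    std::vector<std::unique_ptr<Kernel>> kernels;
    kernels.push_back(std::make_unique<decltype(product)>(std::move(product)));
    kernels.push_back(std::make_unique<decltype(sum)>(std::move(sum)));
    return kernels;
}
```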
If you don't want to template your class, and you need your class to own a stateful function object, then `std::function` is pretty much the only way.
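For comparison, this is what ownership buys: `std::function` keeps its own copy of the callable, captured state included, so the wrapper stays valid after the original lambda is gone. `makeScaledKernel` is a hypothetical example, not from the original post.

```cpp
#include <functional>

// Hypothetical example: the returned std::function owns a copy of the lambda,
// including its captured `scale`, so it remains valid after `k` goes out of
// scope. The price is type erasure: every call is an indirect call, and a
// large enough capture may be heap-allocated.
std::function<double(double, double)> makeScaledKernel(double scale)
{
    auto k = [scale](double X, double Y) { return scale * X * Y; };
    return k;  // the lambda is copied into the std::function
}
```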
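However, if you don't need ownership, or the function or function object is stateless (e.g. a free function), as in your question, I can offer you an alternative.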
Since you said you like the clarity of `std::function`, you could try this non-owning function reference class:
```cpp
#include <memory>       // std::addressof
#include <type_traits>  // std::enable_if_t, std::is_invocable_r_v, std::is_same_v, std::decay_t
#include <utility>      // std::forward

template<typename TSignature> class function_ref;

template<typename TRet, typename ...TParams>
class function_ref<TRet(TParams...)> final
{
    using refptr_t = void*;
    // Type-erased trampoline: receives the stored pointer plus the call arguments.
    using callback_t = TRet (*)(refptr_t, TParams...);

    callback_t m_callback = nullptr;
    refptr_t m_callable = nullptr;

public:
    constexpr function_ref() noexcept = default;
    constexpr function_ref(const function_ref&) noexcept = default;
    constexpr function_ref& operator=(const function_ref&) noexcept = default;
    constexpr function_ref(function_ref&&) noexcept = default;
    constexpr function_ref& operator=(function_ref&&) noexcept = default;
    ~function_ref() noexcept = default;

    template <
        typename T,
        // T& is how the callable is actually invoked; excluding function_ref
        // itself keeps copies from being routed through this constructor.
        typename = std::enable_if_t<
            std::is_invocable_r_v<TRet, T&, TParams...> &&
            !std::is_same_v<std::decay_t<T>, function_ref>
        >
    >
    function_ref(T &&_callable) noexcept :
        m_callback(
            [](refptr_t callable, TParams ...params) -> TRet
            {return (*reinterpret_cast<std::add_pointer_t<T>>(callable))(std::forward<TParams>(params)...);}
        ),
        // C-style cast: also accepts const callables and free functions
        // (function pointer to void* is conditionally supported, e.g. by GCC and Clang).
        m_callable((refptr_t)std::addressof(_callable))
    {}

    TRet operator()(TParams ...params) const
    {
        return m_callback(m_callable, std::forward<TParams>(params)...);
    }

    explicit operator bool() const noexcept { return m_callback != nullptr; }
};
```
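This doesn't have the overhead of `std::function`, because it doesn't need to own the callable, and in my tests it is usually completely inlined with `-O3` optimization. It is my modified implementation of the class Vittorio Romeo discussed in this talk. You still need to watch the lifetime of the function you pass to the constructor, but it is perfect for taking function parameters.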
Usage example:
```cpp
#include <iostream>
#include <utility>

// Assumes the function_ref class template above is defined in this file.

void func(int x)
{
    std::cout<<x<< " I'm a free func!\n";
}
class Obj
{
public:
    void member(int x) { std::cout<<x<< " I'm a member func!\n";}
};
int main()
{
    // Define the signature
    using func_ref_t = function_ref<void(int)>;
    // Can be used with stateful lambdas
    int bar = 1;
    auto lambda = [&bar](int x){std::cout<<x<< " I'm a lambda!\n"; ++bar;};
    // Copy and move
    func_ref_t lref(lambda);
    auto cpy = lref;
    auto mv = std::move(lref);
    cpy(1);
    mv(2);
    // See the modified var from the lambda
    std::cout<<bar<<'\n';
    // Use with free functions
    auto fref = func_ref_t{func};
    fref(4);
    // We can wrap member functions with stateful lambdas
    Obj obj;
    auto mem = [&obj](int x) { obj.member(x); };
    auto mref = func_ref_t{mem};
    mref(5);
}
```
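To connect this back to the question, here is a sketch of passing a kernel into a matrix-filling routine as a non-owning `function_ref` parameter. The `function_ref` below is a condensed variant of the class above (it accepts callable objects but not bare function names) so that the block compiles on its own; `fillCovariance` is a hypothetical helper.

```cpp
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

template <typename Signature> class function_ref;

// Condensed variant of the function_ref class above (callable objects only).
template <typename R, typename... Args>
class function_ref<R(Args...)> {
    void* obj_ = nullptr;
    R (*call_)(void*, Args...) = nullptr;

public:
    template <typename F,
              typename = std::enable_if_t<
                  std::is_invocable_r_v<R, F&, Args...> &&
                  !std::is_same_v<std::decay_t<F>, function_ref>>>
    function_ref(F&& f) noexcept
        : obj_(const_cast<void*>(static_cast<const void*>(std::addressof(f)))),
          call_([](void* o, Args... args) -> R {
              // Restore the callable's real type and invoke it.
              return (*static_cast<std::add_pointer_t<F>>(o))(std::forward<Args>(args)...);
          }) {}

    R operator()(Args... args) const { return call_(obj_, std::forward<Args>(args)...); }
};

// Hypothetical helper: the kernel is borrowed, not owned, so it must stay
// alive until this function returns.
std::vector<double> fillCovariance(const std::vector<double>& x,
                                   function_ref<double(double, double)> k)
{
    const std::size_t n = x.size();
    std::vector<double> K(n * n);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            K[i * n + j] = k(x[i], x[j]);
    return K;
}
```

Passing a temporary lambda, as in `fillCovariance(x, [](double a, double b) { return a * b; })`, is safe because the temporary lives until the call returns; storing such a `function_ref` and calling it later would dangle.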
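Notice: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.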