简体   繁体   English

RcppParallel:RMatrix和RVector算术运算

[英]RcppParallel: RMatrix and RVector arithmetic operations

I am trying to parallelize a double for loop with RcppArmadillo , but I am having trouble with the kind of arithmetic operations that are available to RMatrix and RVector . 我试图用RcppArmadillo并行化一个双循环,但是我遇到了RMatrixRVector可用的算术运算RVector I looked at the header file available on github, and I don't see anything there, so I guess I am looking in the wrong place. 我看了github上可用的头文件 ,我没有看到任何东西,所以我想我在错误的地方看。 This is my worker, and I commented where I am trying to do arithmetic operations between two RMatrix objects. 这是我的工作者,我评论了我在两个RMatrix对象之间进行算术运算的RMatrix

#include <RcppParallel.h>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <Rmath.h>
#include <RcppArmadillo.h>
using namespace RcppParallel;


struct ClosestMean : public Worker {

  // Input data and means matrix
  const RMatrix<double> input_data;
  const RMatrix<double> means;

  // Output labels
  RVector<int> predicted_labels;

  // constructor
  ClosestMean(const Rcpp::NumericMatrix input_data, const Rcpp::NumericMatrix means, Rcpp::IntegerVector predicted_labels)
    : input_data(input_data), means(means), predicted_labels(predicted_labels) {}

  // function call operator for the specified range (begin/end)
  void operator () (std::size_t begin, std::size_t end){
    for (unsigned int i = begin; i < end; i++){

      // Check for User Interrupts
      Rcpp::checkUserInterrupt();

      // Get the label corresponding to the cluster mean
      // for which the point is closest to
      RMatrix<double>::Row point = input_data.row(i);
      int label_min = -1;
      double dist;
      double min_dist = INFINITY;

      for (unsigned int j = 0; j < means.nrow(); j++){
        RMatrix<double>::Row mean = means.row(j);
        dist = sqrt(Rcpp::sum((mean - point)^2)); // This is where the operation is failing
        if (dist < min_dist){
          min_dist = dist;
          label_min = j;
        }
      }

      predicted_labels[i] = label_min;

    }
  }

};

Thanks for any advice. 谢谢你的建议。

Basically, you can't subtract two Row objects like you might do with regular Rcpp vectors (ie, taking advantage of so-called Rcpp sugar ) -- it's just not implemented for the RcppParallel wrappers. 基本上,你不能像使用常规Rcpp向量那样减去两个Row对象(即利用所谓的Rcpp糖 ) - 它只是没有为RcppParallel包装器实现。 You'll have to write the iteration yourself. 你必须自己编写迭代。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM