
16-bit float MPI_Reduce?

I have a distributed application that uses MPI_Reduce() for some of its communication. In terms of precision, 16-bit floating-point numbers (half precision) give us fully accurate results.

To speed up communication (by reducing the amount of data moved), is there a way to call MPI_Reduce() on 16-bit floats?


(I looked through the MPI documentation and didn't find any information about 16-bit floats.)

The MPI standard has no predefined C datatype for 16-bit floats: its predefined real floating-point types for C are MPI_FLOAT (32-bit), MPI_DOUBLE (64-bit), and MPI_LONG_DOUBLE.

However, you can create your own MPI_Datatype and your own custom reduction operation (MPI_Op). The code below gives a rough idea of how to do this. Since it is unclear which 16-bit float implementation you are using, I refer to the type simply as float16_t and to the addition operation as fp16_add().

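#include <mpi.h>

// define custom reduce operation (signature must match MPI_User_function)
void my_fp16_sum(void* invec, void* inoutvec, int *len,
                 MPI_Datatype *datatype) {
    // cast invec and inoutvec to pointers to your float16 type
    float16_t* in = (float16_t*)invec;
    float16_t* inout = (float16_t*)inoutvec;
    (void)datatype;  // unused: this op is only used with mpi_type_float16
    for (int i = 0; i < *len; ++i) {
        // element-wise sum of your 16-bit floats: inout[i] = in[i] + inout[i]
        inout[i] = fp16_add(in[i], inout[i]);
    }
}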

// ...

//  in your code:

// create 2-byte datatype (send raw, un-interpreted bytes)
MPI_Datatype mpi_type_float16;
MPI_Type_contiguous(2, MPI_BYTE, &mpi_type_float16);
MPI_Type_commit(&mpi_type_float16);

// create user op (pass function pointer to your user function)
MPI_Op mpi_fp16sum;
MPI_Op_create(&my_fp16_sum, 1, &mpi_fp16sum);

// call MPI_Reduce using your custom reduction operation
// (fp16_val and fp16_result are float16_t variables; the result lands on rank 0)
MPI_Reduce(&fp16_val, &fp16_result, 1, mpi_type_float16, mpi_fp16sum, 0, MPI_COMM_WORLD);

// clean up: free the custom MPI_Datatype and MPI_Op
MPI_Type_free(&mpi_type_float16);
MPI_Op_free(&mpi_fp16sum);
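The answer leaves float16_t and fp16_add() abstract. If no half-precision library or compiler type is available, one possible stand-in (a minimal sketch, not a definitive implementation) is to store the raw IEEE 754 binary16 bit pattern in a uint16_t, widen to float for the arithmetic, and round the result back to half precision (round to nearest, ties to even). The helper names fp16_to_float, float_to_fp16, and fp16_sum_arrays are hypothetical and are not part of MPI or any library; NaN payloads are not preserved.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical stand-in for float16_t: raw IEEE 754 binary16 bits (2 bytes). */
typedef uint16_t float16_t;

/* Widen binary16 to float; every half value is exactly representable as float. */
static float fp16_to_float(float16_t h)
{
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp  = (h >> 10) & 0x1Fu;
    uint32_t mant = h & 0x3FFu;
    uint32_t bits;

    if (exp == 0 && mant == 0) {
        bits = sign;                                   /* signed zero */
    } else if (exp == 0) {
        exp = 127 - 15 + 1;                            /* subnormal: normalize */
        while ((mant & 0x400u) == 0) {
            mant <<= 1;
            exp--;
        }
        bits = sign | (exp << 23) | ((mant & 0x3FFu) << 13);
    } else if (exp == 0x1F) {
        bits = sign | 0x7F800000u | (mant << 13);      /* infinity or NaN */
    } else {
        bits = sign | ((exp + 127 - 15) << 23) | (mant << 13);
    }

    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Narrow float to binary16, rounding to nearest, ties to even. */
static float16_t float_to_fp16(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);

    uint32_t sign = (bits >> 16) & 0x8000u;
    uint32_t exp  = (bits >> 23) & 0xFFu;
    uint32_t mant = bits & 0x7FFFFFu;

    if (exp == 0xFF)                                   /* infinity or NaN */
        return (float16_t)(sign | 0x7C00u | (mant ? 0x200u : 0u));

    int32_t e = (int32_t)exp - 127 + 15;               /* re-biased exponent */
    if (e >= 0x1F)
        return (float16_t)(sign | 0x7C00u);            /* overflow: infinity */

    if (e <= 0) {                                      /* half subnormal or zero */
        if (e < -10)
            return (float16_t)sign;                    /* below half the smallest subnormal */
        mant |= 0x800000u;                             /* restore implicit leading bit */
        uint32_t shift = (uint32_t)(14 - e);
        uint32_t half_mant = mant >> shift;
        uint32_t rem = mant & ((1u << shift) - 1u);
        uint32_t halfway = 1u << (shift - 1u);
        if (rem > halfway || (rem == halfway && (half_mant & 1u)))
            half_mant++;                               /* may carry into the smallest normal */
        return (float16_t)(sign | half_mant);
    }

    uint32_t half_bits = ((uint32_t)e << 10) | (mant >> 13);
    uint32_t rem = mant & 0x1FFFu;
    if (rem > 0x1000u || (rem == 0x1000u && (half_bits & 1u)))
        half_bits++;                                   /* carry may round up to infinity */
    return (float16_t)(sign | half_bits);
}

/* fp16_add: add in float, then round the sum back to half precision. */
static float16_t fp16_add(float16_t a, float16_t b)
{
    return float_to_fp16(fp16_to_float(a) + fp16_to_float(b));
}

/* Same element-wise contract as the MPI user op: inout[i] = in[i] + inout[i]. */
static void fp16_sum_arrays(const float16_t *in, float16_t *inout, int len)
{
    for (int i = 0; i < len; ++i)
        inout[i] = fp16_add(in[i], inout[i]);
}
```

With these definitions, the loop in my_fp16_sum does the same work as fp16_sum_arrays, and the 2-byte MPI_BYTE datatype matches the uint16_t storage exactly. Note that each reduction step rounds the partial sum back to half precision, so rounding happens at every combine rather than once at the end.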
