繁体   English   中英

Julia 成对广播

[英]Julia pairwise broadcast

我想比较 Julia 中字符串列表中的每一对字符串。 一种方法是

equal_strs = [(x == y) for x in str_list, y in str_list]

但是,如果我按如下方式使用广播:

equal_strs = broadcast(==, str_list, str_list)

它返回一个向量而不是二维数组。 有没有办法使用broadcast output 二维数组?

广播通过扩展(“广播”)长度不同的维度来工作,其方式是使用 NxKx1 广播的(例如)大小为 Nx1xM 的数组给出 NxKxM 数组。

这意味着如果你广播一个长度为 N 个向量的操作,你将得到一个长度为 N 个向量。

所以你需要一个字符串数组是一个长度为 N 的向量,另一个是 1xM 矩阵:

julia> using Random

julia> str1 = [randstring('A':'C', 3) for _ in 1:5]
5-element Vector{String}:
 "ACC"
 "CBC"
 "AAC"
 "CAB"
 "BAB"

1.8.0> str2 = [randstring('A':'C', 3) for _ in 1:4]
4-element Vector{String}:
 "ABB"
 "BAB"
 "CAA"
 "BBC"

1.8.0> str1 .== permutedims(str2)
5×4 BitMatrix:
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  1  0  0

permutedims会将长度为 N 的向量更改为 1xN 矩阵。

顺便说一句,您很少会在代码中使用broadcastbroadcast(==, a, b) ),而是使用点语法, a.== b ,这更惯用。

您应该为广播机器转置一个向量,以通过扩展输入的维度来构建矩阵以达成一致。

julia> str_list = ["hello", "car", "me", "good", "people", "good"];

julia> equal_strs = broadcast(==, str_list, permutedims(str_list))
6×6 BitMatrix:
 1  0  0  0  0  0
 0  1  0  0  0  0
 0  0  1  0  0  0
 0  0  0  1  0  1
 0  0  0  0  1  0
 0  0  0  1  0  1

此外,以下类似。

equal_strs = str_list .== permutedims(str_list)
equal_strs = isequal.(str_list, permutedims(str_list))

将假设“列表”是指向量,因为 Julia 中没有类似 python 的列表。 如果您的意思是一个元组,我建议将其转换为 Vector ,因为广播最好与 Arrays (Vector 是其子类型)一起使用。

str_list = ["one", "two", "three", "one", "two"]

现在你只需做

broadcast(==, str_list, permutedims(str_list))

或更简洁的点运算符

str_list .== permutedims(str_list)

引擎盖下会发生什么:

在 Julia 中广播按元素工作,所以如果你有 2 个向量,它不会做任何事情,因为尺寸匹配。

但是如果你有一个向量和一个矩阵(向量是一个一维数组,矩阵是一个二维数组) ,形状为(N,1)(1,N) Julia 将广播1维给你一个形状矩阵(N,N)这就是你想要的。

现在通常用数字你会做'而不是permutedims

num_list .== num_list'

至于为什么它不适用于字符串,请参阅此答案

正如其他答案所建议的那样, lst.== permutedims(lst)是一种非常好的查找结果的方法。 但它需要 O(n^2) 比较,如果列表很长,使用 O(n*log(n)) 比较算法可能会更好。 以下是带有一点基准的算法的实现:

function equal_str(lst)
    sp = sortperm(lst)
    isp = invperm(sp)
    same = [i==1 ? false : lst[sp[i]]==lst[sp[i-1]] for i=1:length(lst)]
    ac = accumulate((k,v)-> ifelse(v==false, k+1, k), same; init=0)
    return [ ac[isp[i]]==ac[isp[j]] for i=1:length(lst),j=1:length(lst) ]
end

基准给出:

julia> using Random

julia> using BenchmarkTools

julia> lst = [randstring('A':'C',3) for i=1:40];

julia> show(lst)
["CBA", "CAB", "BCA", "AAC", "AAA", "ABC", "BBA", "CAB", "CBC", "CCA",
 "BCC", "BCB", "CAB", "BCB", "ACC", "CBC", "CCC", "CCB", "BCB", "BCB", 
 "ABA", "AAC", "CCC", "ABC", "BAC", "CAB", "BAB", "BCB", "CCA", "CAC", 
 "AAA", "BBC", "ABC", "BCB", "CBA", "CAA", "CAB", "CAC", "CBC", "CBC"]

julia> @btime $lst .== permutedims($lst) ;
  9.025 μs (5 allocations: 4.58 KiB)

julia> @btime equal_str($lst) ;
  6.112 μs (8 allocations: 3.08 KiB)

lst越大,差异越大。 正如 OP 所建议的那样,这仅适用于将列表与其自身进行比较。 为了比较两个列表,应该在 O(n*log(n)) 时间内使用不同的算法。

最后,即使这个算法通过排序也有点太难了,但是 O(n^2) 时间/空间复杂度是产生结果的内在因素。

更新:更线性的 O(n) 时间计算(仍然 O(n^2) 来制作矩阵):

function equal_str_2(lst)
    d = Dict{String,Int}()
    d2 = Dict{Int, Vector{Int}}()
    for p in pairs(lst)
        if haskey(d,p[2])
            push!(d2[d[p[2]]],p[1])
        else
            d[p[2]] = p[1]
            d2[p[1]] = [p[1]]
        end
    end
    res = zeros(Bool, (length(lst), length(lst)))
    for p in values(d2)
        for q in Iterators.product(p,p)
            res[q[1],q[2]] = true
            res[q[2], q[1]] = true
        end
    end
    return res
end

并使用更大的lst进行基准测试:

julia> lst = [randstring('A':'C',3) for i=1:140];

julia> @btime $lst .== permutedims($lst) ;
  99.094 μs (5 allocations: 6.89 KiB)

julia> @btime equal_str($lst) ;
  51.981 μs (9 allocations: 23.12 KiB)

julia> @btime equal_str_2($lst) ;
  21.539 μs (72 allocations: 27.47 KiB)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM