PySpark: join dataframes based on a function

I have two dataframes that look like this:

networks

+----------------+-------+
|    Network     | VLAN  |
+----------------+-------+
| 192.168.1.0/24 | VLAN1 |
| 192.168.2.0/24 | VLAN2 |
+----------------+-------+

flows

+--------------+----------------+
|  source_ip   | destination_ip |
+--------------+----------------+
| 192.168.1.11 | 192.168.2.13   |
+--------------+----------------+

Ideally I would like to get something like this

+--------------+----------------+-------------+------------------+
|  source_ip   | destination_ip | source_vlan | destination_vlan |
+--------------+----------------+-------------+------------------+
| 192.168.1.11 | 192.168.2.13   | VLAN1       | VLAN2            |
+--------------+----------------+-------------+------------------+

Unfortunately, the flows dataframe does not contain the subnet mask. What I have tried so far without PySpark (a pure-Python sketch follows the list below):

  1. Get a distinct list of subnet mask lengths (in this example, [24])
  2. For every subnet mask and source_ip, compute the network: ipaddress.ip_network('{}/{}'.format(ip, sub), strict=False)
  3. Check whether that network exists in the "networks" dataframe and return its VLAN; otherwise return an empty string
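
For reference, here is a minimal pure-Python sketch of those three steps, using the standard ipaddress module and a plain dict as a stand-in for the networks dataframe (names are illustrative):

import ipaddress

# Stand-in for the "networks" dataframe: network string -> VLAN
networks = {"192.168.1.0/24": "VLAN1", "192.168.2.0/24": "VLAN2"}

# Step 1: distinct subnet mask lengths present in the networks table
subnets = {net.split("/")[1] for net in networks}  # {'24'}

def vlan_for_ip(ip):
    # Steps 2-3: derive the candidate network for each mask length
    # and look it up, returning '' when nothing matches
    for sub in subnets:
        net = ipaddress.ip_network('{}/{}'.format(ip, sub), strict=False)
        if str(net) in networks:
            return networks[str(net)]
    return ''

print(vlan_for_ip("192.168.1.11"))  # VLAN1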

I tried a similar approach with PySpark, but it does not work well, and I think there might be better ways of doing it:

import hashlib
import ipaddress
from functools import partial

from pyspark.sql.functions import split, udf
from pyspark.sql.types import StringType

def get_available_subnets(df):
    # Extract the mask length (the part after '/') from each network
    split_col = split(df['network'], '/')
    df = df.withColumn('sub', split_col.getItem(1))
    return df.select('sub').distinct()

def get_vlan_by_ip(ip, infoblox, subnets):
    for sub in subnets:
        # Derive the candidate network for this IP and mask length
        net = ipaddress.ip_network('{}/{}'.format(ip, sub), strict=False)
        search = infoblox.filter(infoblox.network == str(net))
        if search.head(1):  # head(1) returns a (possibly empty) list of Rows
            return search.first().vlan
    return hashlib.sha1(str.encode(ip)).hexdigest()

subnets = get_available_subnets(infoblox_networks_df).select('sub').rdd.flatMap(lambda x: x).collect()


short = flows_filtered_prepared_df.limit(1000)


partial_vlan_func = partial(get_vlan_by_ip, infoblox=infoblox_networks_df, subnets=subnets)
get_vlan_udf = udf(lambda ip: partial_vlan_func(ip), StringType())

short.select('source_ip', 'destination_ip', get_vlan_udf('source_ip').alias('source_vlan')).show()

This method completely avoids the use of a UDF, leveraging split and slice, though perhaps there is a better way. The benefit of this approach is that it directly uses the bits of the subnet mask and that it's written purely in PySpark.

Context for the solution: an IPv4 address can be split into octets and masked by the subnet prefix. Prefix lengths of 8, 16, 24, 32 align with octet boundaries and tell you which leading octets of the IP matter; this motivates dividing the prefix length by 8 and using the resulting column to slice the IP address ArrayType column once it's split from its original StringType. (Note that this assumes prefix lengths that are multiples of 8.)
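
As a quick illustration of that arithmetic in plain Python (values taken from the example data):

# A /24 mask -> 24 / 8 = 3 leading octets are significant
bits = 24 // 8                               # 3
segmented_ip = "192.168.1.11".split(".")     # ['192', '168', '1', '11']
prefix = segmented_ip[:bits]                 # ['192', '168', '1']
# ['192', '168', '1'] equals the masked bits of 192.168.1.0/24 -> VLAN1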

NB: pyspark.sql.functions.slice works in newer versions of PySpark (>= 2.4); on some older ones you need to use f.expr("slice(...)") instead.

The setup:

import pyspark.sql.functions as f

flows = spark.createDataFrame([
    (1, "192.168.1.1", "192.168.2.1"),
    (2, "192.168.2.1", "192.168.3.1"), 
    (3, "192.168.3.1", "192.168.1.1"), 
], ['id', 'source_ip', 'destination_ip'] 
)
networks = spark.createDataFrame([
    (1, "192.168.1.0/24", "VLAN1"),
    (2, "192.168.2.0/24", "VLAN2"), 
    (3, "192.168.3.0/24", "VLAN3"), 
], ['id', 'network', 'vlan'] 
)

Some pre-processing:

networks_split = networks.select(
    "*",
    (f.split(f.col("network"), "/")[1] / 8).cast("int").alias("bits"),
    f.split(f.split(f.col("network"), "/")[0], "\.").alias('segmented_ip')
)
networks_split.show()
+---+--------------+-----+----+----------------+
| id|       network| vlan|bits|    segmented_ip|
+---+--------------+-----+----+----------------+
|  1|192.168.1.0/24|VLAN1|   3|[192, 168, 1, 0]|
|  2|192.168.2.0/24|VLAN2|   3|[192, 168, 2, 0]|
|  3|192.168.3.0/24|VLAN3|   3|[192, 168, 3, 0]|
+---+--------------+-----+----+----------------+

networks_masked = networks_split.select(
    "*",
    f.expr("slice(segmented_ip, 1, bits)").alias("masked_bits"),
)
networks_masked.show()
+---+--------------+-----+----+----------------+-------------+
| id|       network| vlan|bits|    segmented_ip|  masked_bits|
+---+--------------+-----+----+----------------+-------------+
|  1|192.168.1.0/24|VLAN1|   3|[192, 168, 1, 0]|[192, 168, 1]|
|  2|192.168.2.0/24|VLAN2|   3|[192, 168, 2, 0]|[192, 168, 2]|
|  3|192.168.3.0/24|VLAN3|   3|[192, 168, 3, 0]|[192, 168, 3]|
+---+--------------+-----+----+----------------+-------------+

flows_split = flows.select(
    "*",
    f.split(f.split(f.col("source_ip"), "/")[0], "\.").alias('segmented_source_ip'),
    f.split(f.split(f.col("destination_ip"), "/")[0], "\.").alias('segmented_destination_ip')
)
flows_split.show()
+---+-----------+--------------+-------------------+------------------------+
| id|  source_ip|destination_ip|segmented_source_ip|segmented_destination_ip|
+---+-----------+--------------+-------------------+------------------------+
|  1|192.168.1.1|   192.168.2.1|   [192, 168, 1, 1]|        [192, 168, 2, 1]|
|  2|192.168.2.1|   192.168.3.1|   [192, 168, 2, 1]|        [192, 168, 3, 1]|
|  3|192.168.3.1|   192.168.1.1|   [192, 168, 3, 1]|        [192, 168, 1, 1]|
+---+-----------+--------------+-------------------+------------------------+

Then I crossJoin and filter on the slice, based on the bits of the mask:

flows_split.crossJoin(
    networks_masked.select("vlan", "bits", "masked_bits")
).where(
    f.expr("slice(segmented_source_ip, 1, bits)") == f.col("masked_bits")
).show()
+---+-----------+--------------+-------------------+------------------------+-----+----+-------------+
| id|  source_ip|destination_ip|segmented_source_ip|segmented_destination_ip| vlan|bits|  masked_bits|
+---+-----------+--------------+-------------------+------------------------+-----+----+-------------+
|  1|192.168.1.1|   192.168.2.1|   [192, 168, 1, 1]|        [192, 168, 2, 1]|VLAN1|   3|[192, 168, 1]|
|  2|192.168.2.1|   192.168.3.1|   [192, 168, 2, 1]|        [192, 168, 3, 1]|VLAN2|   3|[192, 168, 2]|
|  3|192.168.3.1|   192.168.1.1|   [192, 168, 3, 1]|        [192, 168, 1, 1]|VLAN3|   3|[192, 168, 3]|
+---+-----------+--------------+-------------------+------------------------+-----+----+-------------+

Exactly the same approach works for destination_ip:

flows_split.crossJoin(
    networks_masked.select("vlan", "bits", "masked_bits")
).where(
    f.expr("slice(segmented_destination_ip, 1, bits)") == f.col("masked_bits")
).show()
+---+-----------+--------------+-------------------+------------------------+-----+----+-------------+
| id|  source_ip|destination_ip|segmented_source_ip|segmented_destination_ip| vlan|bits|  masked_bits|
+---+-----------+--------------+-------------------+------------------------+-----+----+-------------+
|  1|192.168.1.1|   192.168.2.1|   [192, 168, 1, 1]|        [192, 168, 2, 1]|VLAN2|   3|[192, 168, 2]|
|  2|192.168.2.1|   192.168.3.1|   [192, 168, 2, 1]|        [192, 168, 3, 1]|VLAN3|   3|[192, 168, 3]|
|  3|192.168.3.1|   192.168.1.1|   [192, 168, 3, 1]|        [192, 168, 1, 1]|VLAN1|   3|[192, 168, 1]|
+---+-----------+--------------+-------------------+------------------------+-----+----+-------------+

Finally, you either join the two resulting tables back together (since each now carries the vlan information as required), or you merge the previous two steps into one and crossJoin and filter twice; a sketch of the join option follows below.
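
For completeness, a minimal sketch of that join: select and rename the vlan column from each filtered crossJoin, then join back on id (the column names source_vlan and destination_vlan are illustrative):

src = flows_split.crossJoin(
    networks_masked.select("vlan", "bits", "masked_bits")
).where(
    f.expr("slice(segmented_source_ip, 1, bits)") == f.col("masked_bits")
).select("id", "source_ip", "destination_ip", f.col("vlan").alias("source_vlan"))

dst = flows_split.crossJoin(
    networks_masked.select("vlan", "bits", "masked_bits")
).where(
    f.expr("slice(segmented_destination_ip, 1, bits)") == f.col("masked_bits")
).select("id", f.col("vlan").alias("destination_vlan"))

src.join(dst, on="id").show()
# Expected shape (row order may differ):
# +---+-----------+--------------+-----------+----------------+
# | id|  source_ip|destination_ip|source_vlan|destination_vlan|
# +---+-----------+--------------+-----------+----------------+
# |  1|192.168.1.1|   192.168.2.1|      VLAN1|           VLAN2|
# |  2|192.168.2.1|   192.168.3.1|      VLAN2|           VLAN3|
# |  3|192.168.3.1|   192.168.1.1|      VLAN3|           VLAN1|
# +---+-----------+--------------+-----------+----------------+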

Hope this helps!
