
How to get the table name from Spark SQL Query [PySpark]?

To get the table name from a SQL Query,

select *
from table1 as t1
full outer join table2 as t2
  on t1.id = t2.id

I found a Scala solution in How to get table names from SQL query?:

def getTables(query: String): Seq[String] = {
  val logicalPlan = spark.sessionState.sqlParser.parsePlan(query)
  import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
  logicalPlan.collect { case r: UnresolvedRelation => r.tableName }
}

which gives me the correct table names when I iterate over the returned sequence with getTables(query).foreach(println):

table1
table2
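Conceptually, Scala's `collect` does a pre-order traversal of the plan tree, applying a partial function to every node and keeping the results where it matches. A pure-Python analogue on a toy tree (the `Node` classes below are hypothetical stand-ins for Catalyst's `TreeNode` hierarchy, not Spark API):

```python
# Toy plan nodes standing in for Catalyst's TreeNode hierarchy (hypothetical).
class Node:
    def __init__(self, *children):
        self.children = list(children)

class Project(Node): pass
class Join(Node): pass

class UnresolvedRelation(Node):
    def __init__(self, table_name):
        super().__init__()
        self.table_name = table_name

def collect(node, pf):
    """Pre-order traversal, like TreeNode.collect in Catalyst.

    pf returns None where the Scala partial function would not match.
    """
    out = []
    result = pf(node)
    if result is not None:
        out.append(result)
    for child in node.children:
        out.extend(collect(child, pf))
    return out

plan = Project(Join(UnresolvedRelation("table1"), UnresolvedRelation("table2")))
tables = collect(
    plan,
    lambda n: n.table_name if isinstance(n, UnresolvedRelation) else None,
)
print(tables)  # ['table1', 'table2']
```

The difficulty in PySpark is that this traversal lives on the JVM side, and Py4J cannot pass a Python lambda as the Scala `PartialFunction` that `collect` expects.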

What would be the equivalent in PySpark? The closest I came across is How to extract column name and column type from SQL in pyspark:

plan = spark_session._jsparkSession.sessionState().sqlParser().parsePlan(query)
print(f"table: {plan.tableDesc().identifier().table()}")

which fails with the traceback

Py4JError: An error occurred while calling o78.tableDesc. Trace:
py4j.Py4JException: Method tableDesc([]) does not exist
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
    at py4j.Gateway.invoke(Gateway.java:274)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.base/java.lang.Thread.run(Thread.java:835)

I understand the problem stems from the fact that I need to filter the plan items that are of type UnresolvedRelation, but I cannot find an equivalent way to do this in Python/PySpark.

I do have an approach, but it is rather convoluted: it dumps the Java object to JSON (a poor man's serialization), deserializes it into Python objects, then filters and parses out the table names:

import json

def get_tables(query: str):
    plan = spark._jsparkSession.sessionState().sqlParser().parsePlan(query)
    # Round-trip through JSON to get the plan nodes as Python dicts
    plan_items = json.loads(plan.toJSON())
    for plan_item in plan_items:
        if plan_item['class'] == 'org.apache.spark.sql.catalyst.analysis.UnresolvedRelation':
            yield plan_item['tableIdentifier']['table']

which yields ['fast_track_gv_nexus', 'buybox_gv_nexus'] for my actual query when I materialize the generator with list(get_tables(query)).
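The JSON-filtering step itself is plain Python, so it can be exercised without a running Spark session. The `plan_json` below is a hand-written stand-in for what `plan.toJSON()` might return for the two-table join at the top of the question; the exact node list is an assumption, and a real plan carries many more fields per node:

```python
import json

# Hypothetical stand-in for plan.toJSON() on the
# "table1 full outer join table2" query; real output is far more verbose.
plan_json = json.dumps([
    {"class": "org.apache.spark.sql.catalyst.plans.logical.Project"},
    {"class": "org.apache.spark.sql.catalyst.plans.logical.Join"},
    {"class": "org.apache.spark.sql.catalyst.analysis.UnresolvedRelation",
     "tableIdentifier": {"table": "table1"}},
    {"class": "org.apache.spark.sql.catalyst.analysis.UnresolvedRelation",
     "tableIdentifier": {"table": "table2"}},
])

def extract_tables(plan_items):
    # Keep only the leaf relations that still carry a raw table identifier.
    for plan_item in plan_items:
        if plan_item["class"].endswith("UnresolvedRelation"):
            yield plan_item["tableIdentifier"]["table"]

print(list(extract_tables(json.loads(plan_json))))  # ['table1', 'table2']
```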

Note: unfortunately, this breaks for CTEs, because a CTE reference also shows up as an UnresolvedRelation in the unresolved plan.

Example:

with delta as (
    select *
    group by id
    cluster by id
)
select *
from (
    select *
    from (
        select *
        from dmm
        inner join delta on dmm.id = delta.id
    )
)

To work around it, I have to resort to a regular expression on the plan's string representation:

import json
import re

def get_tables(query: str):
    plan = spark._jsparkSession.sessionState().sqlParser().parsePlan(query)
    plan_items = json.loads(plan.toJSON())
    # CTE names are rendered as "CTE [name1, name2]" in the plan string,
    # so split each bracketed group on commas.
    cte = [name.strip()
           for group in re.findall(r"CTE \[(.*?)\]", plan.toString())
           for name in group.split(",")]
    for plan_item in plan_items:
        if plan_item['class'] == 'org.apache.spark.sql.catalyst.analysis.UnresolvedRelation':
            table_identifier = plan_item['tableIdentifier']
            table = table_identifier['table']
            database = table_identifier.get('database', '')
            table_name = "{}.{}".format(database, table) if database else table
            if table_name not in cte:
                yield table_name
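The regex step can also be tested on its own. The `plan_string` below is a hand-written fragment imitating how Spark renders a `With` node (`CTE [...]`) in `plan.toString()`; the surrounding lines are an assumption for illustration, not real Spark output:

```python
import re

# Hypothetical fragment of plan.toString() for the CTE query above;
# Spark prints the CTE names inside "CTE [...]" on the With node.
plan_string = """\
'Project [*]
+- 'SubqueryAlias __sample
   +- CTE [delta]
"""

cte_names = re.findall(r"CTE \[(.*?)\]", plan_string)
print(cte_names)  # ['delta']

# A plan with several CTEs lists them comma-separated inside one bracket,
# so each captured group needs an extra split:
flat = [name.strip()
        for group in re.findall(r"CTE \[(.*?)\]", "CTE [delta, gamma]")
        for name in group.split(",")]
print(flat)  # ['delta', 'gamma']
```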
