How to extract column names from SQL query using Python

I would like to extract the column names of a resulting table directly from the SQL statement:


query = """

select 
    sales.order_id as id, 
    p.product_name, 
    sum(p.price) as sales_volume 
from sales
right join products as p 
    on sales.product_id=p.product_id
group by id, p.product_name;

"""

column_names = parse_sql(query)
# column_names:
# ['id', 'product_name', 'sales_volume']

Any idea what to do in parse_sql()? The resulting function should be able to recognize aliases and remove the table aliases/identifiers (e.g. "sales." or "p.").

Thanks in advance!



Solution 1:[1]

I've done something like this using the library sqlparse. Basically, this library takes your SQL query and tokenizes it. Once that is done, you can search for the SELECT keyword token and parse the tokens that follow it. In code, that reads like:

import sqlparse


def find_selected_columns(query) -> list[str]:
    # Only the top-level tokens of the first statement are scanned.
    tokens = sqlparse.parse(query)[0].tokens
    found_select = False
    for token in tokens:
        if found_select:
            # sqlparse groups a multi-column select list into an IdentifierList.
            if isinstance(token, sqlparse.sql.IdentifierList):
                return [
                    # Last word (alias or column), unquoted, without the table prefix.
                    col.value.split(" ")[-1].strip("`").rpartition('.')[-1]
                    for col in token.tokens
                    if isinstance(col, sqlparse.sql.Identifier)
                ]
        else:
            found_select = token.match(sqlparse.tokens.Keyword.DML, ["select", "SELECT"])
    raise Exception("Could not find a select statement. Weird query :)")

This code should also work for queries with common table expressions, i.e. it only returns the columns of the final select. Depending on the SQL dialect and the quote characters you are using, you might have to adapt the line col.value.split(" ")[-1].strip("`").rpartition('.')[-1]
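To make that adaptation more concrete, the string handling can be pulled out into a small helper. The following is only a sketch using the standard library: output_name and its quote_chars parameter are hypothetical names (not part of sqlparse), and it assumes each select-list entry has the shape "expr", "expr alias" or "expr as alias" with no spaces inside quoted names. Unlike the one-liner above, it removes the table prefix before stripping quotes, so that fully quoted names are handled too.

```python
def output_name(identifier_text: str, quote_chars: str = '`"[]') -> str:
    """Return the result column name for one select-list entry.

    Hypothetical helper mirroring the split/strip/rpartition chain,
    but accepting several quote styles instead of backticks only.
    """
    # The alias (or the bare column) is the last whitespace-separated word.
    last_word = identifier_text.split()[-1]
    # Drop a table qualifier such as "sales." or "p." first,
    # so that a quoted qualifier does not leave a stray quote behind.
    unqualified = last_word.rpartition(".")[-1]
    # Finally remove any surrounding quote characters.
    return unqualified.strip(quote_chars)


examples = (
    "sales.order_id as id",
    "p.product_name",
    '"p"."price" AS "total"',
    "[dbo].[Orders]",
)
for text in examples:
    print(text, "->", output_name(text))
```

Inside find_selected_columns, the list comprehension could then call output_name(col.value), with quote_chars set to whatever your dialect uses.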

Solution 2:[2]

Try out SQLGlot.

It's much easier and less error-prone than sqlparse.

import sqlglot
import sqlglot.expressions as exp

query = """
select
    sales.order_id as id,
    p.product_name,
    sum(p.price) as sales_volume
from sales
right join products as p
    on sales.product_id=p.product_id
group by id, p.product_name;

"""

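column_names = []

# Iterate over the select list of the first SELECT node found in the syntax tree.
for expression in sqlglot.parse_one(query).find(exp.Select).args["expressions"]:
    if isinstance(expression, exp.Alias):
        # Aliased expression: the output column is named after the alias.
        column_names.append(expression.text("alias"))
    elif isinstance(expression, exp.Column):
        # Plain column: keep the column name without its table qualifier.
        column_names.append(expression.text("this"))

print(column_names)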

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Simon Hawe
Solution 2 Toby Mao