How to split a list into multiple columns in PySpark?

I have:

key   value
a    [1,2,3]
b    [2,3,4]

I want:

key value1 value2 value3
a     1      2      3
b     2      3      4

It seems that in Scala I can write df.select($"value._1", $"value._2", $"value._3"), but this is not possible in Python.

So is there a good way to do this?



Solution 1:[1]

It depends on the type of your "list":

  • If it is of type ArrayType():

    # 'hc' is a HiveContext/SQLContext and 'sc' the SparkContext
    # (with a SparkSession, spark.createDataFrame works the same way).
    df = hc.createDataFrame(sc.parallelize([['a', [1,2,3]], ['b', [2,3,4]]]), ["key", "value"])
    df.printSchema()
    df.show()
    root
     |-- key: string (nullable = true)
     |-- value: array (nullable = true)
     |    |-- element: long (containsNull = true)

    +---+-------+
    |key|  value|
    +---+-------+
    |  a|[1,2,3]|
    |  b|[2,3,4]|
    +---+-------+

    you can access the values as you would in Python, using [] (a dynamic variant is sketched after this list):

    df.select("key", df.value[0], df.value[1], df.value[2]).show()
    +---+--------+--------+--------+
    |key|value[0]|value[1]|value[2]|
    +---+--------+--------+--------+
    |  a|       1|       2|       3|
    |  b|       2|       3|       4|
    +---+--------+--------+--------+
    
  • If it is of type StructType() (for example, if you built your dataframe by reading JSON):

    import pyspark.sql.functions as psf  # provides psf.struct used below

    df2 = df.select("key", psf.struct(
            df.value[0].alias("value1"),
            df.value[1].alias("value2"),
            df.value[2].alias("value3")
        ).alias("value"))
    df2.printSchema()
    df2.show()
    root
     |-- key: string (nullable = true)
     |-- value: struct (nullable = false)
     |    |-- value1: long (nullable = true)
     |    |-- value2: long (nullable = true)
     |    |-- value3: long (nullable = true)
    
    +---+-------+
    |key|  value|
    +---+-------+
    |  a|[1,2,3]|
    |  b|[2,3,4]|
    +---+-------+
    

    you can directly 'split' the column using *:

    df2.select('key', 'value.*').show()
    +---+------+------+------+
    |key|value1|value2|value3|
    +---+------+------+------+
    |  a|     1|     2|     3|
    |  b|     2|     3|     4|
    +---+------+------+------+
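
For completeness, here is a minimal sketch of the ArrayType case where the number of element columns is derived from the data instead of being hardcoded (it assumes a SparkSession named spark and that all rows have arrays of the same length; getItem() is equivalent to the [] syntax used above):

    import pyspark.sql.functions as psf

    df = spark.createDataFrame([('a', [1, 2, 3]), ('b', [2, 3, 4])], ["key", "value"])

    # Take the array length from the first row (assumes equal-length arrays).
    n = len(df.first()["value"])

    # getItem(i) is equivalent to df.value[i]; alias() renames the resulting columns.
    df.select("key",
              *[psf.col("value").getItem(i).alias("value" + str(i + 1)) for i in range(n)]
              ).show()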
    

Solution 2:[2]

I'd like to add the case of fixed-size lists (arrays) to pault's answer.

If the column contains medium-sized (or even large) arrays, it is still possible to split them into columns.

from pyspark.sql.types import *          # Needed to define DataFrame Schema.
from pyspark.sql.functions import expr   

# Define schema to create DataFrame with an array typed column.
mySchema = StructType([StructField("V1", StringType(), True),
                       StructField("V2", ArrayType(IntegerType(),True))])

df = spark.createDataFrame([['A', [1, 2, 3, 4, 5, 6, 7]], 
                            ['B', [8, 7, 6, 5, 4, 3, 2]]], schema= mySchema)

# Split the list into columns using 'expr()' in a list comprehension.
arr_size = 7
df = df.select(['V1', 'V2']+[expr('V2[' + str(x) + ']') for x in range(0, arr_size)])

# It is possible to define new column names.
new_colnames = ['V1', 'V2'] + ['val_' + str(i) for i in range(0, arr_size)] 
df = df.toDF(*new_colnames)

The result is:

df.show(truncate=False)

+---+---------------------+-----+-----+-----+-----+-----+-----+-----+
|V1 |V2                   |val_0|val_1|val_2|val_3|val_4|val_5|val_6|
+---+---------------------+-----+-----+-----+-----+-----+-----+-----+
|A  |[1, 2, 3, 4, 5, 6, 7]|1    |2    |3    |4    |5    |6    |7    |
|B  |[8, 7, 6, 5, 4, 3, 2]|8    |7    |6    |5    |4    |3    |2    |
+---+---------------------+-----+-----+-----+-----+-----+-----+-----+
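
If the array length is not known up front, one option (just a sketch, computed right after creating df and before the select above) is to derive arr_size from the data rather than hardcoding it to 7:

    from pyspark.sql.functions import size, max as max_

    # Largest array length found in the column (7 for the example data).
    arr_size = df.agg(max_(size("V2"))).collect()[0][0]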

Solution 3:[3]

I needed to unlist a 712-dimensional array into columns in order to write it to CSV. I first used @MaFF's solution for my problem, but it seemed to cause a lot of errors and additional computation time. I am not sure what was causing it, but I used a different method which reduced the computation time considerably (22 minutes compared to more than 4 hours)!

@MaFF's method:

length = len(dataset.head()["list_col"])
dataset = dataset.select(dataset.columns + [dataset["list_col"][k] for k in range(length)])

What I used:

dataset = dataset.rdd.map(lambda x: (*x, *x["list_col"])).toDF()

If someone has any idea what was causing this difference in computation time, please let me know! I suspect that in my case the bottleneck was calling head() to get the list length (which I would like to be adaptive), and that it mattered because (i) my data pipeline was quite long and exhaustive, and (ii) I had to unlist multiple columns. Furthermore, caching the entire dataset was not an option.
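
Applied to the small example from the question, this RDD-based approach would look roughly like the sketch below (the column names passed to toDF() are made up here; a bare toDF() would produce generic _1, _2, ... names):

    df = spark.createDataFrame([('a', [1, 2, 3]), ('b', [2, 3, 4])], ["key", "list_col"])

    # Unpack each Row, append the unpacked list elements, and rebuild a DataFrame.
    wide = df.rdd.map(lambda x: (*x, *x["list_col"])) \
                 .toDF(["key", "list_col", "value1", "value2", "value3"])
    wide.show()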

Solution 4:[4]

For ArrayType data, to do it dynamically, you can do something like:

df2.select(['key'] + [df2.features[x] for x in range(0,3)])
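
To give the resulting columns readable names in the same pass, an alias can be added inside the comprehension (a sketch; the "feature_" name pattern is made up):

    from pyspark.sql.functions import col

    df2.select(['key'] + [col("features")[x].alias("feature_" + str(x)) for x in range(0, 3)])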

Solution 5:[5]

@Jordi Aceiton, thanks for the solution. I tried to make it more concise: the loop for renaming the newly created columns is removed and the renaming is done while creating them, and df.columns is used to fetch the existing column names rather than building the list manually.

from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import Row

df = spark.createDataFrame([Row(index=1, finalArray=[1.1, 2.3, 7.5], c=4),
                            Row(index=2, finalArray=[9.6, 4.1, 5.4], c=4)])
# Collect all the existing column names as a list.
dlist = df.columns
# Append the new columns to the dataframe.
df.select(dlist + [(col("finalArray")[x]).alias("Value" + str(x + 1)) for x in range(0, 3)]).show()

Output:

 +---------------+-----+------+------+------+
 |  finalArray   |index|Value1|Value2|Value3|
 +---------------+-----+------+------+------+
 |[1.1, 2.3, 7.5]|  1  |   1.1|   2.3|   7.5|
 |[9.6, 4.1, 5.4]|  2  |   9.6|   4.1|   5.4|
 +---------------+-----+------+------+------+
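
As a further alternative (not from the answers above, just a sketch using the same example DataFrame): posexplode() plus pivot() can also turn an array column into one column per position, at the cost of a shuffle; the pivoted columns are named after the positions ("0", "1", "2") and can be renamed afterwards with toDF():

    from pyspark.sql.functions import posexplode, first

    # Explode each array element together with its position, then pivot positions into columns.
    exploded = df.select("index", posexplode("finalArray").alias("pos", "val"))
    wide = exploded.groupBy("index").pivot("pos").agg(first("val"))
    wide.show()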

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1: pault
[2] Solution 2: Jordi Aceiton
[3] Solution 3: thijsvdp
[4] Solution 4: Bhargav Rao
[5] Solution 5: Oli