How to add custom methods to the PySpark DataFrame class by inheritance

I am trying to inherit the DataFrame class and add additional custom methods as below, so that I can chain calls fluently and also ensure all methods refer to the same dataframe. I get an exception: "Column is not iterable".

from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import lit

class MyClass(DataFrame):
    def __init__(self, df):
        super().__init__(df._jdf, df.sql_ctx)

    def add_column3(self):
        # Add col3 to the dataframe received
        self._jdf.withColumn("col3", lit(3))
        return self

    def add_column4(self):
        # Add col4 to the dataframe received
        self._jdf.withColumn("col4", lit(4))
        return self

if __name__ == "__main__":
    '''
    Spark Context initialization code
    col1 col2
    a    1
    b    2
    '''
    df = spark.createDataFrame([("a", 1), ("b", 2)], ["col1", "col2"])
    myobj = MyClass(df)
    ## Trying to accomplish the below, where I can chain MyClass methods & DataFrame methods
    myobj.add_column3().add_column4().drop_columns(["col1"])

'''
Expected Output
col2, col3,col4
1,3,4
2,3,4
'''


Solution 1:[1]

Below is my solution (which is based on your code). I don't know if it's best practice, but it at least does what you want correctly. DataFrames are immutable objects, so after we add a new column we create a new object, but not a plain DataFrame object: a MyClass object, because we want both the DataFrame methods and the custom methods.

from pyspark.sql.dataframe import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()


class MyClass(DataFrame):
    def __init__(self, df):
        super().__init__(df._jdf, df.sql_ctx)
        self._df = df

    def add_column3(self):
        # Add col3 to the dataframe received
        newDf = self._df.withColumn("col3", F.lit(3))
        return MyClass(newDf)

    def add_column4(self):
        # Add col4 to the dataframe received
        newDf = self._df.withColumn("col4", F.lit(4))
        return MyClass(newDf)


df = spark.createDataFrame([("a",1), ("b",2)], ["col1","col2"])
myobj = MyClass(df)
myobj.add_column3().add_column4().na.drop().show()

# Result:
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
|   a|   1|   3|   4|
|   b|   2|   3|   4|
+----+----+----+----+
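Because each custom method returns a MyClass instance, inherited DataFrame methods such as drop can be chained in the same way. A minimal sketch reusing the MyClass and df defined above, which should reproduce the expected output from the question:

myobj = MyClass(df)
myobj.add_column3().add_column4().drop("col1").show()

# Should print something like:
# +----+----+----+
# |col2|col3|col4|
# +----+----+----+
# |   1|   3|   4|
# |   2|   3|   4|
# +----+----+----+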

Solution 2:[2]

Actually, you don't need to inherit the DataFrame class in order to add custom methods to DataFrame objects.

In Python, you can add a custom property that wraps your methods like this:

from functools import wraps

from pyspark.sql import DataFrame
from pyspark.sql.functions import lit

# decorator to attach a function to an attribute
def add_attr(cls):
    def decorator(func):
        @wraps(func)
        def _wrapper(*args, **kwargs):
            f = func(*args, **kwargs)
            return f

        setattr(cls, func.__name__, _wrapper)
        return func

    return decorator

# custom functions
def custom(self):
    @add_attr(custom)
    def add_column3():
        return self.withColumn("col3", lit(3))

    @add_attr(custom)
    def add_column4():
        return self.withColumn("col4", lit(4))

    return custom

# add the new property to the pyspark.sql.DataFrame class
DataFrame.custom = property(custom)

# use it
df.custom.add_column3().show()
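Note that add_column3 returns a plain DataFrame, so chaining a second custom method means going through the custom property again. A small sketch based on the snippet above:

df.custom.add_column3().custom.add_column4().show()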

Solution 3:[3]

I think you are looking for something like this:

class dfc:
  def __init__(self, df):
    self.df = df
    
  def func(self, num):
    self.df = self.df.selectExpr(f"id * {num} AS id")
  
  def func1(self, num1):
    self.df = self.df.selectExpr(f"id * {num1} AS id")
    
  def dfdis(self):
    self.df.show()

In this example, a dataframe is passed to the constructor and is then used by the methods defined inside the class. The state of the dataframe is stored in the instantiated object and updated whenever the corresponding methods are called.

df = spark.range(10)

ob = dfc(df)
ob.func(2)
ob.func(2)
ob.dfdis()
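If the fluent chaining from the question is the goal, the methods can also return self. This is a small variation on the class above (the name dfc_chained is just for illustration; it is not part of the original answer):

class dfc_chained:
    def __init__(self, df):
        self.df = df

    def func(self, num):
        self.df = self.df.selectExpr(f"id * {num} AS id")
        return self  # returning self enables chaining

    def dfdis(self):
        self.df.show()

dfc_chained(spark.range(10)).func(2).func(2).dfdis()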

Solution 4:[4]

The answer by blackbishop is worth a look, even if it has no upvotes as of this writing. It seems like a good general approach for extending the Spark DataFrame class, and I presume other complex objects as well. I rewrote it slightly as follows:

from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import col
from functools import wraps

# Create a decorator to add a function to a python object
def add_attr(cls):
    def decorator(func):
        @wraps(func)
        def _wrapper(*args, **kwargs):
            f = func(*args, **kwargs)
            return f

        setattr(cls, func.__name__, _wrapper)
        return func

    return decorator

  
# Extensions to the Spark DataFrame class go here
def dataframe_extension(self):
    @add_attr(dataframe_extension)
    def drop_fusion_gdpp_events():
        return (
            self
            .where(~((col('test1') == 'ABC') & (col('test2') == 'XYZ')))
            .where(~col('test1').isin(['AAA', 'BBB']))
        )
    return dataframe_extension

DataFrame.dataframe_extension = property(dataframe_extension)
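Usage would then look something like the following, assuming a DataFrame that actually has test1 and test2 columns (the original answer does not show a call site):

# df is assumed to contain 'test1' and 'test2' columns
filtered_df = df.dataframe_extension.drop_fusion_gdpp_events()
filtered_df.show()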

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 ggeop
Solution 2 blackbishop
Solution 3
Solution 4