
PySpark – collect_list() and collect_set()

PySpark is the Python API for Apache Spark. It lets you process data with Spark using DataFrames.

In this article, we will discuss the collect_list() and collect_set() functions for PySpark DataFrames.

Before moving to these functions, we will create a PySpark DataFrame.

Example:

Here we are going to create a PySpark DataFrame with 5 rows and 6 columns.

#import the pyspark module
import pyspark
#import SparkSession for creating a session
from pyspark.sql import SparkSession

#create an app named linuxhint
spark_app = SparkSession.builder.appName('linuxhint').getOrCreate()

# create student data with 5 rows and 6 attributes
students1 =[{'rollno':'001','name':'sravan','age':23,'height':5.79,'weight':67,'address':'guntur'},
               {'rollno':'002','name':'ojaswi','age':16,'height':3.79,'weight':34,'address':'hyd'},
               {'rollno':'003','name':'gnanesh chowdary','age':7,'height':2.79,'weight':17,'address':'patna'},
               {'rollno':'004','name':'rohith','age':9,'height':2.79,'weight':28,'address':'hyd'},
               {'rollno':'005','name':'sridevi','age':9,'height':5.59,'weight':54,'address':'hyd'}]

# create the dataframe
df = spark_app.createDataFrame( students1)

# display dataframe
df.show()


PySpark – collect_list()

collect_list() is an aggregate function that gathers the values of a PySpark DataFrame column into a single list. It keeps every value, including duplicates. When the result is retrieved with the collect() method, it is returned as a list of Row objects. collect_list() must be imported from the pyspark.sql.functions module.
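To make the semantics concrete, here is a plain-Python model of what collect_list() produces for one column. It is only an illustration of the result, not Spark's implementation: the helper name collect_list_model and the trimmed rows (only name and address) are assumptions for this sketch. Like Spark's collect_list(), it keeps duplicates and skips null values.

```python
# Plain-Python model of collect_list() on one column (illustration only,
# not Spark's implementation). collect_list_model is a hypothetical helper.
rows = [
    {"name": "sravan", "address": "guntur"},
    {"name": "ojaswi", "address": "hyd"},
    {"name": "gnanesh chowdary", "address": "patna"},
    {"name": "rohith", "address": "hyd"},
    {"name": "sridevi", "address": "hyd"},
]


def collect_list_model(records, column):
    """Return every non-null value of `column`, duplicates included."""
    return [r[column] for r in records if r.get(column) is not None]


addresses = collect_list_model(rows, "address")
print(addresses)  # ['guntur', 'hyd', 'patna', 'hyd', 'hyd']
```

In Spark itself, the order of the collected values follows the order in which rows reach the aggregation, which is not guaranteed after a shuffle; the model simply keeps list order.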

Syntax:

dataframe.select(collect_list("column"))

where:

  1. dataframe is the input PySpark DataFrame
  2. column is the name of the column to which collect_list() is applied

Example 1:

In this example, we collect the values of the address column and retrieve the result with the collect() method.

#import the pyspark module
import pyspark
#import SparkSession for creating a session
from pyspark.sql import SparkSession
#import collect_list function
from pyspark.sql.functions import collect_list

#create an app named linuxhint
spark_app = SparkSession.builder.appName('linuxhint').getOrCreate()

# create student data with 5 rows and 6 attributes
students1 =[{'rollno':'001','name':'sravan','age':23,'height':5.79,'weight':67,'address':'guntur'},
               {'rollno':'002','name':'ojaswi','age':16,'height':3.79,'weight':34,'address':'hyd'},
               {'rollno':'003','name':'gnanesh chowdary','age':7,'height':2.79,'weight':17,'address':'patna'},
               {'rollno':'004','name':'rohith','age':9,'height':3.69,'weight':28,'address':'hyd'},
               {'rollno':'005','name':'sridevi','age':37,'height':5.59,'weight':54,'address':'hyd'}]

# create the dataframe
df = spark_app.createDataFrame( students1)

# collect_list on address
df.select(collect_list("address")).collect()

Output:

[Row(collect_list(address)=['guntur', 'hyd', 'patna', 'hyd', 'hyd'])]

Example 2:

In this example, we collect the values of the height and weight columns and retrieve the result with the collect() method.

#import the pyspark module
import pyspark
#import SparkSession for creating a session
from pyspark.sql import SparkSession
#import collect_list function
from pyspark.sql.functions import collect_list

#create an app named linuxhint
spark_app = SparkSession.builder.appName('linuxhint').getOrCreate()

# create student data with 5 rows and 6 attributes
students1 =[{'rollno':'001','name':'sravan','age':23,'height':5.79,'weight':67,'address':'guntur'},
               {'rollno':'002','name':'ojaswi','age':16,'height':3.79,'weight':34,'address':'hyd'},
               {'rollno':'003','name':'gnanesh chowdary','age':7,'height':2.79,'weight':17,'address':'patna'},
               {'rollno':'004','name':'rohith','age':9,'height':3.69,'weight':28,'address':'hyd'},
               {'rollno':'005','name':'sridevi','age':37,'height':5.59,'weight':54,'address':'hyd'}]

# create the dataframe
df = spark_app.createDataFrame( students1)

# collect_list on height and weight columns
df.select(collect_list("height"),collect_list("weight")).collect()

Output:

[Row(collect_list(height)=[5.79, 3.79, 2.79, 3.69, 5.59], collect_list(weight)=[67, 34, 17, 28, 54])]

PySpark – collect_set()

collect_set() is also an aggregate function that gathers the values of a PySpark DataFrame column into a single list, but it removes duplicates, so each distinct value appears only once. The order of the returned values is not guaranteed. Like collect_list(), its result is returned as a list of Row objects when retrieved with collect(), and it must be imported from the pyspark.sql.functions module.
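Continuing the plain-Python illustration (again a model of the result, not Spark's implementation; collect_set_model is a hypothetical helper), collect_set() behaves like building a set from the column. Returning a Python set makes the lack of ordering explicit.

```python
# Plain-Python model of collect_set() on one column (illustration only).
# collect_set_model is a hypothetical helper, not part of PySpark.
addresses = ['guntur', 'hyd', 'patna', 'hyd', 'hyd']  # the address column


def collect_set_model(values):
    """Distinct non-null values; a Python set carries no order."""
    return {v for v in values if v is not None}


distinct = collect_set_model(addresses)
print(sorted(distinct))  # ['guntur', 'hyd', 'patna']
```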

Syntax:

dataframe.select(collect_set("column"))

where:

  1. dataframe is the input PySpark DataFrame
  2. column is the name of the column to which collect_set() is applied

Example 1:

In this example, we collect the distinct values of the address column and retrieve the result with the collect() method.

#import the pyspark module
import pyspark
#import SparkSession for creating a session
from pyspark.sql import SparkSession
#import collect_set function
from pyspark.sql.functions import collect_set

#create an app named linuxhint
spark_app = SparkSession.builder.appName('linuxhint').getOrCreate()

# create student data with 5 rows and 6 attributes
students1 =[{'rollno':'001','name':'sravan','age':23,'height':5.79,'weight':67,'address':'guntur'},
               {'rollno':'002','name':'ojaswi','age':16,'height':3.79,'weight':34,'address':'hyd'},
               {'rollno':'003','name':'gnanesh chowdary','age':7,'height':2.79,'weight':17,'address':'patna'},
               {'rollno':'004','name':'rohith','age':9,'height':3.69,'weight':28,'address':'hyd'},
               {'rollno':'005','name':'sridevi','age':37,'height':5.59,'weight':54,'address':'hyd'}]

# create the dataframe
df = spark_app.createDataFrame( students1)

# collect_set on address
df.select(collect_set("address")).collect()

Output:

[Row(collect_set(address)=['hyd', 'guntur', 'patna'])]

Example 2:

In this example, we apply collect_set() to the height column and, for comparison, collect_list() to the weight column, then retrieve the result with the collect() method.

#import the pyspark module
import pyspark
#import SparkSession for creating a session
from pyspark.sql import SparkSession
#import collect_set function
from pyspark.sql.functions import collect_set

#create an app named linuxhint
spark_app = SparkSession.builder.appName('linuxhint').getOrCreate()

# create student data with 5 rows and 6 attributes
students1 =[{'rollno':'001','name':'sravan','age':23,'height':5.79,'weight':67,'address':'guntur'},
               {'rollno':'002','name':'ojaswi','age':16,'height':3.79,'weight':34,'address':'hyd'},
               {'rollno':'003','name':'gnanesh chowdary','age':7,'height':2.79,'weight':17,'address':'patna'},
               {'rollno':'004','name':'rohith','age':9,'height':3.69,'weight':28,'address':'hyd'},
               {'rollno':'005','name':'sridevi','age':37,'height':5.59,'weight':54,'address':'hyd'}]

# create the dataframe
df = spark_app.createDataFrame( students1)

# collect_set on height, collect_list on weight
df.select(collect_set("height"),collect_list("weight")).collect()

Output:

[Row(collect_set(height)=[5.59, 3.69, 2.79, 5.79, 3.79], collect_list(weight)=[67, 34, 17, 28, 54])]

Conclusion

We have seen that collect_list() and collect_set() gather the values of a PySpark DataFrame column into a list. The difference is that collect_list() keeps duplicate values, while collect_set() removes them.

About the author

Gottumukkala Sravan Kumar

B.Tech (Hons) in Information Technology; known programming languages: Python, R, PHP, MySQL; published 500+ articles in the computer science domain