- Introduction
- Setting Up PySpark Environment
- Creating DataFrames
- Exploring DataFrames
- Data Selection and Filtering
- Data Transformation
- Aggregations
- Window Functions
- Joins
- Set Operations
- Sorting and Ordering
- User-Defined Functions (UDFs)
- Writing Data
- Performance Optimization
- Common Patterns and Recipes
- Tips for Working with Large Datasets
This comprehensive guide covers essential operations for data processing tasks using PySpark's DataFrame API, organized for efficient reference when working with large datasets.
# Import necessary libraries
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql.window import Window
# Create a SparkSession (getOrCreate reuses an already-active session if one exists)
spark = (
    SparkSession.builder
    .appName("Your Application Name")
    .config("spark.some.config.option", "some-value")
    .getOrCreate()
)
# Set log level (optional)
spark.sparkContext.setLogLevel("WARN")
# From existing RDD
rdd = spark.sparkContext.parallelize([(1, "John"), (2, "Jane")])
# Column names are given explicitly; column types are inferred from the data.
df = spark.createDataFrame(rdd, ["id", "name"])
# From Lists
data = [(1, "John", 30), (2, "Jane", 25)]
df = spark.createDataFrame(data, ["id", "name", "age"])
# From Pandas DataFrame
import pandas as pd
pandas_df = pd.DataFrame({"id": [1, 2], "name": ["John", "Jane"]})
df = spark.createDataFrame(pandas_df)
# Reading from files
# NOTE: inferSchema=True costs one extra pass over the CSV to guess column types;
# for large files, pass an explicit schema instead.
df_csv = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)
df_json = spark.read.json("path/to/file.json")
df_parquet = spark.read.parquet("path/to/file.parquet")
df_orc = spark.read.orc("path/to/file.orc")
# read.text yields a single string column named "value", one row per line
df_text = spark.read.text("path/to/file.txt")
# Reading from database
# The PostgreSQL JDBC driver jar must be available on Spark's classpath.
# NOTE(review): user/password are placeholders; in real code load credentials
# from a secret store or environment variable instead of hard-coding them.
df_jdbc = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://host:port/database") \
    .option("dbtable", "schema.table") \
    .option("user", "username") \
    .option("password", "password") \
    .load()
# Basic DataFrame inspection
df.show() # Display first 20 rows
df.show(n=5, truncate=False) # Show 5 rows without truncating
df.printSchema() # Print schema
df.count() # Number of rows (an action: runs a full Spark job)
# The bare expressions below only display their value in a REPL/notebook;
# in a script, assign or print them.
df.columns # List column names
df.dtypes # List column names and data types
# DataFrame statistics
df.describe().show() # Statistical summary: count, mean, stddev, min, max
df.summary().show() # More detailed summary: describe() plus 25%/50%/75% percentiles
# Sampling
# fraction is a per-row probability, so the result is roughly (not exactly) 10% of rows;
# the seed makes the sample reproducible.
sample_df = df.sample(fraction=0.1, seed=42)
# Select columns
df.select("name", "age").show()
df.select(df["name"], df["age"] + 1).show() # With expressions
# Using column expressions
from pyspark.sql.functions import col
df.select(col("name"), col("age") + 1).show()
# Filter rows
df.filter(df["age"] > 25).show()
df.filter("age > 25").show() # SQL expression
df.where(df["age"] > 25).show() # Alternative to filter (where is an alias)
# Handling NULL values
# Comparisons such as df["age"] == None do not match nulls; use isNull/isNotNull.
df.filter(df["age"].isNull()).show()
df.filter(df["age"].isNotNull()).show()
# Distinct values
df.select("department").distinct().show()
# Distinct based on subset of columns; which row is kept per group is not guaranteed
df.dropDuplicates(["name", "department"]).show()
# Limiting results
df.limit(10).show()
# Adding new columns
df = df.withColumn("age_plus_one", df["age"] + 1)
# when/otherwise behaves like SQL CASE WHEN; a null age is not >= 18, so it maps to "No"
df = df.withColumn("adult", F.when(df["age"] >= 18, "Yes").otherwise("No"))
# Renaming columns
# Assigned to a new variable: reassigning df would remove the "age" column and make
# the later examples that reference df["age"] fail with an AnalysisException.
df_renamed = df.withColumnRenamed("age", "years")
# Dropping columns
df = df.drop("age_plus_one")
# Casting data types
df = df.withColumn("age", df["age"].cast(T.IntegerType()))
# Working with strings
df = df.withColumn("upper_name", F.upper(df["name"]))
df = df.withColumn("name_length", F.length(df["name"]))
df = df.withColumn("first_char", F.substring(df["name"], 1, 1))  # positions are 1-based
df = df.withColumn("trimmed", F.trim(df["name"]))
# concat returns null if any input is null; F.concat_ws skips null inputs instead
df = df.withColumn("concatenated", F.concat(df["name"], F.lit(" - "), df["department"]))
# Date functions
df = df.withColumn("current_date", F.current_date())
df = df.withColumn("current_timestamp", F.current_timestamp())
df = df.withColumn("date_add", F.date_add(df["date_col"], 1))
df = df.withColumn("year", F.year(df["date_col"]))
df = df.withColumn("month", F.month(df["date_col"]))
df = df.withColumn("day", F.dayofmonth(df["date_col"]))
# Handling nulls
# coalesce returns the first non-null argument, falling back to the literal
df = df.withColumn("cleaned_col", F.coalesce(df["col1"], df["col2"], F.lit("default")))
df = df.na.fill(0, ["age"])  # Fill nulls with 0 in age column
df = df.na.fill({"age": 0, "name": "Unknown"})  # Fill multiple columns
# Dropping rows with nulls: assigned to new variables so the running df keeps all of
# its rows for the examples that follow.
df_no_nulls = df.na.drop()  # Drop rows with any null values
df_no_nulls_subset = df.na.drop(subset=["age", "name"])  # Drop rows with nulls in specific columns
# Simple aggregations
df.select(F.max("age"), F.min("age"), F.avg("age")).show()
# GroupBy operations
df.groupBy("department").count().show()
df.groupBy("department").agg(F.avg("salary").alias("avg_salary")).show()
# Multiple aggregations
df.groupBy("department").agg(
    F.count("*").alias("count"),  # counts all rows; F.count("col") skips nulls in col
    F.sum("salary").alias("total_salary"),
    F.avg("salary").alias("avg_salary"),
    F.min("salary").alias("min_salary"),
    F.max("salary").alias("max_salary")
).show()
# Common aggregation functions
# F.sum(), F.count(), F.avg(), F.min(), F.max(), F.count_distinct(),
# F.sum_distinct(), F.stddev(), F.variance()
# (count_distinct/sum_distinct are the Spark 3.2+ names; countDistinct still works,
# sumDistinct is deprecated)
# Pivot tables
# Passing the pivot values explicitly, e.g. .pivot("country", ["US", "UK"]), avoids
# the extra job Spark otherwise runs to collect the distinct country values.
df.groupBy("department").pivot("country").sum("salary").show()
# Define a window specification
windowSpec = Window.partitionBy("department").orderBy("salary")
# Rank employees by salary within each department
# (ascending: rank 1 = lowest salary; use F.col("salary").desc() to rank highest first)
# rank: ties share a rank and leave gaps (1, 1, 3); dense_rank: no gaps (1, 1, 2);
# row_number: unique 1, 2, 3, ... with ties broken arbitrarily.
df = df.withColumn("rank", F.rank().over(windowSpec))
df = df.withColumn("dense_rank", F.dense_rank().over(windowSpec))
df = df.withColumn("row_number", F.row_number().over(windowSpec))
# Calculate cumulative and moving aggregations
# NOTE: with orderBy and no explicit frame, the frame defaults to
# RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so rows with equal salary
# receive the same cumulative total.
df = df.withColumn("cumulative_salary", F.sum("salary").over(windowSpec))
# Moving average over the current row and the 2 preceding rows, ordered by date
df = df.withColumn("moving_avg", F.avg("salary").over(
    Window.partitionBy("department").orderBy("date").rowsBetween(-2, 0)
))
# Get values from previous/next rows (null at the partition edges)
df = df.withColumn("prev_salary", F.lag("salary", 1).over(windowSpec))
df = df.withColumn("next_salary", F.lead("salary", 1).over(windowSpec))
# Different join types
# NOTE: joining on an expression keeps both id columns, so referring to "id" in the
# result is ambiguous; df1.join(df2, "id", "inner") (or on=["id"]) keeps a single id.
joined_df = df1.join(df2, df1["id"] == df2["id"], "inner") # Inner join
joined_df = df1.join(df2, df1["id"] == df2["id"], "left") # Left outer join
joined_df = df1.join(df2, df1["id"] == df2["id"], "right") # Right outer join
joined_df = df1.join(df2, df1["id"] == df2["id"], "full") # Full outer join
# Semi/anti joins return only df1's columns: rows with (semi) / without (anti) a match
joined_df = df1.join(df2, df1["id"] == df2["id"], "leftsemi") # Left semi join
joined_df = df1.join(df2, df1["id"] == df2["id"], "leftanti") # Left anti join
# Join on multiple conditions
joined_df = df1.join(df2,
    (df1["id"] == df2["id"]) & (df1["dept"] == df2["dept"]),
    "inner")
# Cross join / Cartesian product (result has rows(df1) * rows(df2) rows)
cross_joined_df = df1.crossJoin(df2)
# Union (combines and keeps duplicates)
# union matches columns by position, so both sides need the same column order
union_df = df1.union(df2)
# Union by name (matches columns by name, not position)
union_df = df1.unionByName(df2)
# Union by name with schema resolution (columns missing on one side become null; Spark 3.1+)
union_df = df1.unionByName(df2, allowMissingColumns=True)
# Intersect (common rows, duplicates removed; intersectAll keeps them)
intersect_df = df1.intersect(df2)
# Except with duplicates preserved (multiset difference: rows in df1 not matched in df2).
# For SQL EXCEPT / MINUS semantics (distinct rows), use df1.subtract(df2).
except_df = df1.exceptAll(df2)
# Sort by a single column
sorted_df = df.orderBy("age")  # ascending by default
# Nulls sort first with asc() and last with desc(); use asc_nulls_last() /
# desc_nulls_first() to override.
sorted_df = df.orderBy(df["age"].asc()) # Ascending
sorted_df = df.orderBy(df["age"].desc()) # Descending
# Sort by multiple columns: department ascending, then salary descending
sorted_df = df.orderBy("department", F.col("salary").desc())
# Define a Python function
def celsius_to_fahrenheit(celsius):
    """Convert a Celsius temperature to Fahrenheit.

    Spark passes SQL NULL values to Python UDFs as ``None``; those are
    returned unchanged (as NULL) instead of raising ``TypeError`` and
    failing the whole task.
    """
    if celsius is None:
        return None
    return (celsius * 9/5) + 32
# Register UDF
# The return type must be declared: F.udf defaults to StringType when it is omitted.
celsius_to_fahrenheit_udf = F.udf(celsius_to_fahrenheit, T.DoubleType())
# Apply UDF (rows are sent to a Python worker one at a time, which is why built-in
# functions or pandas UDFs are preferred when they can express the logic)
df = df.withColumn("temp_fahrenheit", celsius_to_fahrenheit_udf(df["temp_celsius"]))
# Pandas UDFs (vectorized UDFs - much faster)
# They receive whole batches as pandas Series (transferred via Apache Arrow) rather
# than one row at a time; PyArrow must be installed.
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd
@pandas_udf(DoubleType())
def pandas_celsius_to_fahrenheit(celsius: pd.Series) -> pd.Series:
    """Convert a batch of Celsius temperatures to Fahrenheit, element-wise."""
    return (celsius * 9/5) + 32
df = df.withColumn("temp_fahrenheit", pandas_celsius_to_fahrenheit(df["temp_celsius"]))
# Save as Parquet (recommended for performance)
# NOTE: without .mode(...), writes use the default "errorifexists" mode and fail if
# the output path already exists.
df.write.parquet("path/to/output.parquet")
# Save as CSV
df.write.csv("path/to/output.csv", header=True)
# Save as JSON
df.write.json("path/to/output.json")
# Save to database
# The chain is wrapped in parentheses instead of backslash continuations: a trailing
# comment on a line without a backslash ends the statement, which left `.save()`
# stranded (a SyntaxError) and nothing would have been written.
# NOTE(review): credentials are placeholders; load them from a secret store in real code.
(
    df.write
    .format("jdbc")
    .option("url", "jdbc:postgresql://host:port/database")
    .option("dbtable", "schema.table")
    .option("user", "username")
    .option("password", "password")
    .mode("append")  # or "overwrite", "ignore", "error"
    .save()
)
# Partition by columns (for better query performance)
# Creates one sub-directory per distinct value, e.g. year=2024/month=1/
df.write.partitionBy("year", "month").parquet("path/to/partitioned_output")
# Save with specific options (snappy is already Spark's default Parquet codec)
(
    df.write
    .format("parquet")
    .mode("overwrite")
    .option("compression", "snappy")
    .save("path/to/output_with_options")
)
# Cache a DataFrame in memory
# cache() is lazy: the data is only materialized by the first action that computes df
df.cache()
# Or with storage level specification (an alternative to cache(), not a follow-up to it)
# NOTE(review): persisting a DataFrame that is already cached is ignored with a
# warning, so the level below would not take effect after cache() — confirm for
# your Spark version.
from pyspark.storagelevel import StorageLevel
df.persist(StorageLevel.MEMORY_AND_DISK)
# Release from memory
df.unpersist()
# Repartition (increase/decrease partitions; always performs a full shuffle)
df_repartitioned = df.repartition(10)
# Repartition by specific columns (rows with equal key values land in the same partition)
df_repartitioned = df.repartition("department", "country")
# Coalesce (only reduces partitions - avoids a full shuffle, so cheaper than
# repartition, but partitions may end up uneven)
df_coalesced = df.coalesce(5)
# Broadcast join for small tables: df2 is copied to every executor so df1 is not shuffled
from pyspark.sql.functions import broadcast
joined_df = df1.join(broadcast(df2), df1["id"] == df2["id"])
# Salting technique for skewed keys
from pyspark.sql.functions import rand
num_salts = 10
# Add a random salt in [0, num_salts) to each row of the large, skewed side so rows
# sharing a hot key are spread across num_salts join tasks instead of one.
df_salted = df.withColumn("salt", (F.rand(seed=42) * num_salts).cast("int"))
# Replicate every row of the other side once per salt value so each salted row on
# the skewed side still finds its match.
other_salted = other_df.withColumn(
    "salt", F.explode(F.array([F.lit(i) for i in range(num_salts)]))
)
# Join on BOTH the key and the salt (joining on the key alone ignores the salt and
# leaves the skew in place), then drop the helper column.
salted_join = df_salted.join(other_salted, on=["key", "salt"]).drop("salt")
# Working with arrays
# The split delimiter is interpreted as a Java regular expression
df = df.withColumn("array_column", F.split(df["string_col"], ","))
df = df.withColumn("array_length", F.size(df["array_column"]))
df = df.withColumn("contains_item", F.array_contains(df["array_column"], "item"))
df = df.withColumn("first_item", df["array_column"].getItem(0))  # array indices are 0-based
df = df.withColumn("sorted_array", F.sort_array(df["array_column"]))
# Explode arrays (one row per array element)
# explode drops rows whose array is null or empty; F.explode_outer keeps them with a null
df_exploded = df.select("id", F.explode("array_column").alias("single_item"))
# Explode with position (pos starts at 0)
df_exploded = df.select("id", F.posexplode("array_column").alias("pos", "single_item"))
# Working with maps
# create_map takes alternating key, value arguments
df = df.withColumn("map_values", F.create_map(
    F.lit("key1"), F.col("value1"),
    F.lit("key2"), F.col("value2")
))
df = df.withColumn("key1_value", df["map_values"].getItem("key1"))
df = df.withColumn("map_keys", F.map_keys(df["map_values"]))
df = df.withColumn("map_values_only", F.map_values(df["map_values"]))
# Explode maps (one row per key/value entry)
df_exploded = df.select("id", F.explode("map_values").alias("key", "value"))
# Working with structs
df = df.withColumn("struct_col", F.struct("col1", "col2", "col3"))
df = df.withColumn("extracted_field", df["struct_col"]["col1"])
# Calculate time difference between events
# unix_timestamp returns whole seconds, so sub-second precision is dropped
df = df.withColumn(
    "time_diff_seconds",
    F.unix_timestamp(df["current_time"]) - F.unix_timestamp(df["previous_time"])
)
# Group time into bins (truncate each timestamp to the start of its hour)
df = df.withColumn(
    "hour_bin",
    F.date_trunc("hour", df["timestamp"])
)
# Moving averages: current row plus the 3 preceding rows (4 periods) per id
windowSpec = Window.partitionBy("id").orderBy("timestamp").rowsBetween(-3, 0)
df = df.withColumn("moving_avg_4_periods", F.avg("value").over(windowSpec))
# Lag/lead calculations (prev_value is null for the first row of each id)
windowSpec = Window.partitionBy("id").orderBy("timestamp")
df = df.withColumn("prev_value", F.lag("value", 1).over(windowSpec))
df = df.withColumn("change", df["value"] - df["prev_value"])
# Percent change. The when() guard yields null for a zero (or null) previous value
# instead of raising a divide-by-zero error when ANSI SQL mode is enabled (the
# default since Spark 4.0); with ANSI off the result is unchanged.
df = df.withColumn(
    "pct_change",
    F.when(
        df["prev_value"] != 0,
        (df["value"] - df["prev_value"]) / df["prev_value"] * 100,
    ),
)
# Handling exceptions in UDFs
def safe_divide(a, b):
    """Return ``a / b``, or ``None`` when the division cannot be performed.

    Intended to be wrapped as a Spark UDF, where a ``None`` result becomes
    SQL NULL. Only the expected failures are caught: ``TypeError`` (an
    operand is ``None``, i.e. SQL NULL, or non-numeric) and ``ArithmeticError``
    (``ZeroDivisionError``, or ``OverflowError`` for huge integers). Anything
    else still propagates instead of being silently swallowed.
    """
    try:
        return a / b
    except (TypeError, ArithmeticError):
        return None  # Return None for errors
safe_divide_udf = F.udf(safe_divide, T.DoubleType())
# Validate data with when/otherwise
# Conditions are checked in order, so a row with both a negative age and a negative
# salary is labeled "Invalid age". Rows where age or salary is null fall through to
# "Valid", because comparisons with null are never true.
df = df.withColumn(
    "status",
    F.when(df["age"] < 0, "Invalid age")
    .when(df["salary"] < 0, "Invalid salary")
    .otherwise("Valid")
)
# Get a sample of problematic records
problem_records = df.filter("status != 'Valid'").limit(10)
problem_records.show()
# Print execution plans for debugging.
# explain() prints the plan itself and returns None, so wrapping it in print()
# would also print a stray "None".
df.explain()  # Physical plan
df.explain(extended=True)  # Parsed, analyzed and optimized logical plans plus the physical plan
# - Use appropriate file formats (Parquet or ORC) instead of CSV/JSON for better compression and performance
- Partition your data sensibly based on frequent query patterns
- Use broadcast joins for small dataframes to avoid shuffling
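- Minimize shuffles (wide transformations such as groupBy, join, distinct and repartition) and prefer built-in functions over Python UDFs; chains of narrow transformations are already pipelined by the optimizer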
- Persist/cache intermediate results that are used multiple times
- Use appropriate data types (e.g., using IntegerType instead of StringType for numeric data)
- Consider repartitioning if your partitions are very skewed
- Use explain() to understand and optimize execution plans
Remember that PySpark evaluates lazily: transformations are not executed until an action (such as show(), count(), or collect()) is called. This lets Spark optimize the whole execution plan before running it.