spark-optimization by wshobson/agents
npx skills add https://github.com/wshobson/agents --skill spark-optimization优化 Apache Spark 作业的生产模式,包括分区策略、内存管理、Shuffle 优化和性能调优。
Driver Program
↓
Job (triggered by action)
↓
Stages (separated by shuffles)
↓
Tasks (one per partition)
| 因素 | 影响 | 解决方案 |
|---|---|---|
| Shuffle | 网络 I/O,磁盘 I/O | 最小化宽依赖转换 |
| 数据倾斜 | 任务执行时间不均 | 加盐,广播连接 |
| 序列化 | CPU 开销 | 使用 Kryo,列式存储格式 |
广告位招租
在这里展示您的产品或服务
触达数万 AI 开发者,精准高效
| 内存 | GC 压力,溢出 | 调整执行器内存 |
| 分区 | 并行度 | 合理设置分区大小 |
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
# 创建优化的 Spark 会话
spark = (SparkSession.builder
.appName("OptimizedJob")
.config("spark.sql.adaptive.enabled", "true")
.config("spark.sql.adaptive.coalescePartitions.enabled", "true")
.config("spark.sql.adaptive.skewJoin.enabled", "true")
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.sql.shuffle.partitions", "200")
.getOrCreate())
# 使用优化设置读取数据
df = (spark.read
.format("parquet")
.option("mergeSchema", "false")
.load("s3://bucket/data/"))
# 高效转换
result = (df
.filter(F.col("date") >= "2024-01-01")
.select("id", "amount", "category")
.groupBy("category")
.agg(F.sum("amount").alias("total")))
result.write.mode("overwrite").parquet("s3://bucket/output/")
# 计算最优分区数量
def calculate_partitions(data_size_gb: float, partition_size_mb: int = 128) -> int:
"""
最优分区大小:128MB - 256MB
过少:利用率不足,内存压力大
过多:任务调度开销大
"""
return max(int(data_size_gb * 1024 / partition_size_mb), 1)
# 为均匀分布进行重分区
df_repartitioned = df.repartition(200, "partition_key")
# 合并以减少分区(无 Shuffle)
df_coalesced = df.coalesce(100)
# 使用谓词下推进行分区剪裁
df = (spark.read.parquet("s3://bucket/data/")
.filter(F.col("date") == "2024-01-01")) # Spark 将此谓词下推
# 为未来查询使用分区写入
(df.write
.partitionBy("year", "month", "day")
.mode("overwrite")
.parquet("s3://bucket/partitioned_output/"))
from pyspark.sql import functions as F
from pyspark.sql.types import *
# 1. 广播连接 - 小表连接
# 最佳时机:一侧 < 10MB(可配置)
small_df = spark.read.parquet("s3://bucket/small_table/") # < 10MB
large_df = spark.read.parquet("s3://bucket/large_table/") # TBs
# 显式广播提示
result = large_df.join(
F.broadcast(small_df),
on="key",
how="left"
)
# 2. 排序合并连接 - 大表的默认方式
# 需要 Shuffle,但可处理任意大小
result = large_df1.join(large_df2, on="key", how="inner")
# 3. 分桶连接 - 预排序,连接时无 Shuffle
# 写入分桶表
(df.write
.bucketBy(200, "customer_id")
.sortBy("customer_id")
.mode("overwrite")
.saveAsTable("bucketed_orders"))
# 连接分桶表(无 Shuffle!)
orders = spark.table("bucketed_orders")
customers = spark.table("bucketed_customers") # 相同的分桶数量
result = orders.join(customers, on="customer_id")
# 4. 倾斜连接处理
# 启用 AQE 倾斜连接优化
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
# 针对严重倾斜的手动加盐
def salt_join(df_skewed, df_other, key_col, num_salts=10):
"""添加盐值以分发倾斜键"""
# 为倾斜侧添加盐值
df_salted = df_skewed.withColumn(
"salt",
(F.rand() * num_salts).cast("int")
).withColumn(
"salted_key",
F.concat(F.col(key_col), F.lit("_"), F.col("salt"))
)
# 使用所有盐值展开另一侧
df_exploded = df_other.crossJoin(
spark.range(num_salts).withColumnRenamed("id", "salt")
).withColumn(
"salted_key",
F.concat(F.col(key_col), F.lit("_"), F.col("salt"))
)
# 在加盐键上连接
return df_salted.join(df_exploded, on="salted_key", how="inner")
from pyspark import StorageLevel
# 当多次复用 DataFrame 时进行缓存
df = spark.read.parquet("s3://bucket/data/")
df_filtered = df.filter(F.col("status") == "active")
# 缓存在内存中(MEMORY_AND_DISK 是默认值)
df_filtered.cache()
# 或使用特定的存储级别
df_filtered.persist(StorageLevel.MEMORY_AND_DISK_SER)
# 强制物化
df_filtered.count()
# 在多个操作中使用
agg1 = df_filtered.groupBy("category").count()
agg2 = df_filtered.groupBy("region").sum("amount")
# 完成后取消持久化
df_filtered.unpersist()
# 存储级别说明:
# MEMORY_ONLY - 快速,但可能放不下
# MEMORY_AND_DISK - 需要时溢出到磁盘(推荐)
# MEMORY_ONLY_SER - 序列化,内存占用少,CPU 开销大
# DISK_ONLY - 内存紧张时使用
# OFF_HEAP - Tungsten 堆外内存
# 为复杂血缘关系设置检查点
spark.sparkContext.setCheckpointDir("s3://bucket/checkpoints/")
df_complex = (df
.join(other_df, "key")
.groupBy("category")
.agg(F.sum("amount")))
df_complex.checkpoint() # 打断血缘关系,物化数据
# 执行器内存配置
# spark-submit --executor-memory 8g --executor-cores 4
# 内存分解(8GB 执行器):
# - spark.memory.fraction = 0.6 (60% = 4.8GB 用于执行 + 存储)
# - spark.memory.storageFraction = 0.5 (4.8GB 的 50% = 2.4GB 用于缓存)
# - 剩余 2.4GB 用于执行(Shuffle、连接、排序)
# - 40% = 3.2GB 用于用户数据结构和内部元数据
spark = (SparkSession.builder
.config("spark.executor.memory", "8g")
.config("spark.executor.memoryOverhead", "2g") # 用于非 JVM 内存
.config("spark.memory.fraction", "0.6")
.config("spark.memory.storageFraction", "0.5")
.config("spark.sql.shuffle.partitions", "200")
# 针对内存密集型操作
.config("spark.sql.autoBroadcastJoinThreshold", "50MB")
# 防止大型 Shuffle 时发生 OOM
.config("spark.sql.files.maxPartitionBytes", "128MB")
.getOrCreate())
# 监控内存使用情况
def print_memory_usage(spark):
"""打印当前内存使用情况"""
sc = spark.sparkContext
for executor in sc._jsc.sc().getExecutorMemoryStatus().keySet().toArray():
mem_status = sc._jsc.sc().getExecutorMemoryStatus().get(executor)
total = mem_status._1() / (1024**3)
free = mem_status._2() / (1024**3)
print(f"{executor}: {total:.2f}GB total, {free:.2f}GB free")
# 减少 Shuffle 数据大小
spark.conf.set("spark.sql.shuffle.partitions", "auto") # 使用 AQE
spark.conf.set("spark.shuffle.compress", "true")
spark.conf.set("spark.shuffle.spill.compress", "true")
# Shuffle 前进行预聚合
df_optimized = (df
# 先进行本地聚合(合并器)
.groupBy("key", "partition_col")
.agg(F.sum("value").alias("partial_sum"))
# 然后进行全局聚合
.groupBy("key")
.agg(F.sum("partial_sum").alias("total")))
# 通过 Map 端操作避免 Shuffle
# 不好:为每个 distinct 操作进行 Shuffle
distinct_count = df.select("category").distinct().count()
# 好:近似 distinct(无 Shuffle)
approx_count = df.select(F.approx_count_distinct("category")).collect()[0][0]
# 减少分区时使用 coalesce 而不是 repartition
df_reduced = df.coalesce(10) # 无 Shuffle
# 使用压缩优化 Shuffle
spark.conf.set("spark.io.compression.codec", "lz4") # 快速压缩
# Parquet 优化
(df.write
.option("compression", "snappy") # 快速压缩
.option("parquet.block.size", 128 * 1024 * 1024) # 128MB 行组
.parquet("s3://bucket/output/"))
# 列剪裁 - 仅读取需要的列
df = (spark.read.parquet("s3://bucket/data/")
.select("id", "amount", "date")) # Spark 仅读取这些列
# 谓词下推 - 在存储层进行过滤
df = (spark.read.parquet("s3://bucket/partitioned/year=2024/")
.filter(F.col("status") == "active")) # 下推到 Parquet 读取器
# Delta Lake 优化
(df.write
.format("delta")
.option("optimizeWrite", "true") # 装箱打包
.option("autoCompact", "true") # 压缩小文件
.mode("overwrite")
.save("s3://bucket/delta_table/"))
# 用于多维查询的 Z-Ordering
spark.sql("""
OPTIMIZE delta.`s3://bucket/delta_table/`
ZORDER BY (customer_id, date)
""")
# 启用详细指标
spark.conf.set("spark.sql.codegen.wholeStage", "true")
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
# 解释查询计划
df.explain(mode="extended")
# 模式:simple, extended, codegen, cost, formatted
# 获取物理计划统计信息
df.explain(mode="cost")
# 监控任务指标
def analyze_stage_metrics(spark):
"""分析最近阶段的指标"""
status_tracker = spark.sparkContext.statusTracker()
for stage_id in status_tracker.getActiveStageIds():
stage_info = status_tracker.getStageInfo(stage_id)
print(f"Stage {stage_id}:")
print(f" Tasks: {stage_info.numTasks}")
print(f" Completed: {stage_info.numCompletedTasks}")
print(f" Failed: {stage_info.numFailedTasks}")
# 识别数据倾斜
def check_partition_skew(df):
"""检查分区倾斜"""
partition_counts = (df
.withColumn("partition_id", F.spark_partition_id())
.groupBy("partition_id")
.count()
.orderBy(F.desc("count")))
partition_counts.show(20)
stats = partition_counts.select(
F.min("count").alias("min"),
F.max("count").alias("max"),
F.avg("count").alias("avg"),
F.stddev("count").alias("stddev")
).collect()[0]
skew_ratio = stats["max"] / stats["avg"]
print(f"倾斜比例:{skew_ratio:.2f}x (>2x 表示存在倾斜)")
# 生产配置模板
spark_configs = {
# 自适应查询执行 (AQE)
"spark.sql.adaptive.enabled": "true",
"spark.sql.adaptive.coalescePartitions.enabled": "true",
"spark.sql.adaptive.skewJoin.enabled": "true",
# 内存
"spark.executor.memory": "8g",
"spark.executor.memoryOverhead": "2g",
"spark.memory.fraction": "0.6",
"spark.memory.storageFraction": "0.5",
# 并行度
"spark.sql.shuffle.partitions": "200",
"spark.default.parallelism": "200",
# 序列化
"spark.serializer": "org.apache.spark.serializer.KryoSerializer",
"spark.sql.execution.arrow.pyspark.enabled": "true",
# 压缩
"spark.io.compression.codec": "lz4",
"spark.shuffle.compress": "true",
# 广播
"spark.sql.autoBroadcastJoinThreshold": "50MB",
# 文件处理
"spark.sql.files.maxPartitionBytes": "128MB",
"spark.sql.files.openCostInBytes": "4MB",
}
.count() 检查存在性 - 使用 .take(1) 或 .isEmpty()每周安装量
3.3K
仓库
GitHub 星标数
32.2K
首次出现
Jan 20, 2026
安全审计
安装于
claude-code2.5K
gemini-cli2.4K
opencode2.4K
cursor2.4K
codex2.3K
github-copilot2.1K
Production patterns for optimizing Apache Spark jobs including partitioning strategies, memory management, shuffle optimization, and performance tuning.
Driver Program
↓
Job (triggered by action)
↓
Stages (separated by shuffles)
↓
Tasks (one per partition)
| Factor | Impact | Solution |
|---|---|---|
| Shuffle | Network I/O, disk I/O | Minimize wide transformations |
| Data Skew | Uneven task duration | Salting, broadcast joins |
| Serialization | CPU overhead | Use Kryo, columnar formats |
| Memory | GC pressure, spills | Tune executor memory |
| Partitions | Parallelism | Right-size partitions |
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
# Create optimized Spark session
spark = (SparkSession.builder
.appName("OptimizedJob")
.config("spark.sql.adaptive.enabled", "true")
.config("spark.sql.adaptive.coalescePartitions.enabled", "true")
.config("spark.sql.adaptive.skewJoin.enabled", "true")
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.sql.shuffle.partitions", "200")
.getOrCreate())
# Read with optimized settings
df = (spark.read
.format("parquet")
.option("mergeSchema", "false")
.load("s3://bucket/data/"))
# Efficient transformations
result = (df
.filter(F.col("date") >= "2024-01-01")
.select("id", "amount", "category")
.groupBy("category")
.agg(F.sum("amount").alias("total")))
result.write.mode("overwrite").parquet("s3://bucket/output/")
# Calculate optimal partition count
def calculate_partitions(data_size_gb: float, partition_size_mb: int = 128) -> int:
"""
Optimal partition size: 128MB - 256MB
Too few: Under-utilization, memory pressure
Too many: Task scheduling overhead
"""
return max(int(data_size_gb * 1024 / partition_size_mb), 1)
# Repartition for even distribution
df_repartitioned = df.repartition(200, "partition_key")
# Coalesce to reduce partitions (no shuffle)
df_coalesced = df.coalesce(100)
# Partition pruning with predicate pushdown
df = (spark.read.parquet("s3://bucket/data/")
.filter(F.col("date") == "2024-01-01")) # Spark pushes this down
# Write with partitioning for future queries
(df.write
.partitionBy("year", "month", "day")
.mode("overwrite")
.parquet("s3://bucket/partitioned_output/"))
from pyspark.sql import functions as F
from pyspark.sql.types import *
# 1. Broadcast Join - Small table joins
# Best when: One side < 10MB (configurable)
small_df = spark.read.parquet("s3://bucket/small_table/") # < 10MB
large_df = spark.read.parquet("s3://bucket/large_table/") # TBs
# Explicit broadcast hint
result = large_df.join(
F.broadcast(small_df),
on="key",
how="left"
)
# 2. Sort-Merge Join - Default for large tables
# Requires shuffle, but handles any size
result = large_df1.join(large_df2, on="key", how="inner")
# 3. Bucket Join - Pre-sorted, no shuffle at join time
# Write bucketed tables
(df.write
.bucketBy(200, "customer_id")
.sortBy("customer_id")
.mode("overwrite")
.saveAsTable("bucketed_orders"))
# Join bucketed tables (no shuffle!)
orders = spark.table("bucketed_orders")
customers = spark.table("bucketed_customers") # Same bucket count
result = orders.join(customers, on="customer_id")
# 4. Skew Join Handling
# Enable AQE skew join optimization
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
# Manual salting for severe skew
def salt_join(df_skewed, df_other, key_col, num_salts=10):
"""Add salt to distribute skewed keys"""
# Add salt to skewed side
df_salted = df_skewed.withColumn(
"salt",
(F.rand() * num_salts).cast("int")
).withColumn(
"salted_key",
F.concat(F.col(key_col), F.lit("_"), F.col("salt"))
)
# Explode other side with all salts
df_exploded = df_other.crossJoin(
spark.range(num_salts).withColumnRenamed("id", "salt")
).withColumn(
"salted_key",
F.concat(F.col(key_col), F.lit("_"), F.col("salt"))
)
# Join on salted key
return df_salted.join(df_exploded, on="salted_key", how="inner")
from pyspark import StorageLevel
# Cache when reusing DataFrame multiple times
df = spark.read.parquet("s3://bucket/data/")
df_filtered = df.filter(F.col("status") == "active")
# Cache in memory (MEMORY_AND_DISK is default)
df_filtered.cache()
# Or with specific storage level
df_filtered.persist(StorageLevel.MEMORY_AND_DISK_SER)
# Force materialization
df_filtered.count()
# Use in multiple actions
agg1 = df_filtered.groupBy("category").count()
agg2 = df_filtered.groupBy("region").sum("amount")
# Unpersist when done
df_filtered.unpersist()
# Storage levels explained:
# MEMORY_ONLY - Fast, but may not fit
# MEMORY_AND_DISK - Spills to disk if needed (recommended)
# MEMORY_ONLY_SER - Serialized, less memory, more CPU
# DISK_ONLY - When memory is tight
# OFF_HEAP - Tungsten off-heap memory
# Checkpoint for complex lineage
spark.sparkContext.setCheckpointDir("s3://bucket/checkpoints/")
df_complex = (df
.join(other_df, "key")
.groupBy("category")
.agg(F.sum("amount")))
df_complex.checkpoint() # Breaks lineage, materializes
# Executor memory configuration
# spark-submit --executor-memory 8g --executor-cores 4
# Memory breakdown (8GB executor):
# - spark.memory.fraction = 0.6 (60% = 4.8GB for execution + storage)
# - spark.memory.storageFraction = 0.5 (50% of 4.8GB = 2.4GB for cache)
# - Remaining 2.4GB for execution (shuffles, joins, sorts)
# - 40% = 3.2GB for user data structures and internal metadata
spark = (SparkSession.builder
.config("spark.executor.memory", "8g")
.config("spark.executor.memoryOverhead", "2g") # For non-JVM memory
.config("spark.memory.fraction", "0.6")
.config("spark.memory.storageFraction", "0.5")
.config("spark.sql.shuffle.partitions", "200")
# For memory-intensive operations
.config("spark.sql.autoBroadcastJoinThreshold", "50MB")
# Prevent OOM on large shuffles
.config("spark.sql.files.maxPartitionBytes", "128MB")
.getOrCreate())
# Monitor memory usage
def print_memory_usage(spark):
"""Print current memory usage"""
sc = spark.sparkContext
for executor in sc._jsc.sc().getExecutorMemoryStatus().keySet().toArray():
mem_status = sc._jsc.sc().getExecutorMemoryStatus().get(executor)
total = mem_status._1() / (1024**3)
free = mem_status._2() / (1024**3)
print(f"{executor}: {total:.2f}GB total, {free:.2f}GB free")
# Reduce shuffle data size
spark.conf.set("spark.sql.shuffle.partitions", "auto") # With AQE
spark.conf.set("spark.shuffle.compress", "true")
spark.conf.set("spark.shuffle.spill.compress", "true")
# Pre-aggregate before shuffle
df_optimized = (df
# Local aggregation first (combiner)
.groupBy("key", "partition_col")
.agg(F.sum("value").alias("partial_sum"))
# Then global aggregation
.groupBy("key")
.agg(F.sum("partial_sum").alias("total")))
# Avoid shuffle with map-side operations
# BAD: Shuffle for each distinct
distinct_count = df.select("category").distinct().count()
# GOOD: Approximate distinct (no shuffle)
approx_count = df.select(F.approx_count_distinct("category")).collect()[0][0]
# Use coalesce instead of repartition when reducing partitions
df_reduced = df.coalesce(10) # No shuffle
# Optimize shuffle with compression
spark.conf.set("spark.io.compression.codec", "lz4") # Fast compression
# Parquet optimizations
(df.write
.option("compression", "snappy") # Fast compression
.option("parquet.block.size", 128 * 1024 * 1024) # 128MB row groups
.parquet("s3://bucket/output/"))
# Column pruning - only read needed columns
df = (spark.read.parquet("s3://bucket/data/")
.select("id", "amount", "date")) # Spark only reads these columns
# Predicate pushdown - filter at storage level
df = (spark.read.parquet("s3://bucket/partitioned/year=2024/")
.filter(F.col("status") == "active")) # Pushed to Parquet reader
# Delta Lake optimizations
(df.write
.format("delta")
.option("optimizeWrite", "true") # Bin-packing
.option("autoCompact", "true") # Compact small files
.mode("overwrite")
.save("s3://bucket/delta_table/"))
# Z-ordering for multi-dimensional queries
spark.sql("""
OPTIMIZE delta.`s3://bucket/delta_table/`
ZORDER BY (customer_id, date)
""")
# Enable detailed metrics
spark.conf.set("spark.sql.codegen.wholeStage", "true")
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
# Explain query plan
df.explain(mode="extended")
# Modes: simple, extended, codegen, cost, formatted
# Get physical plan statistics
df.explain(mode="cost")
# Monitor task metrics
def analyze_stage_metrics(spark):
"""Analyze recent stage metrics"""
status_tracker = spark.sparkContext.statusTracker()
for stage_id in status_tracker.getActiveStageIds():
stage_info = status_tracker.getStageInfo(stage_id)
print(f"Stage {stage_id}:")
print(f" Tasks: {stage_info.numTasks}")
print(f" Completed: {stage_info.numCompletedTasks}")
print(f" Failed: {stage_info.numFailedTasks}")
# Identify data skew
def check_partition_skew(df):
"""Check for partition skew"""
partition_counts = (df
.withColumn("partition_id", F.spark_partition_id())
.groupBy("partition_id")
.count()
.orderBy(F.desc("count")))
partition_counts.show(20)
stats = partition_counts.select(
F.min("count").alias("min"),
F.max("count").alias("max"),
F.avg("count").alias("avg"),
F.stddev("count").alias("stddev")
).collect()[0]
skew_ratio = stats["max"] / stats["avg"]
print(f"Skew ratio: {skew_ratio:.2f}x (>2x indicates skew)")
# Production configuration template
spark_configs = {
# Adaptive Query Execution (AQE)
"spark.sql.adaptive.enabled": "true",
"spark.sql.adaptive.coalescePartitions.enabled": "true",
"spark.sql.adaptive.skewJoin.enabled": "true",
# Memory
"spark.executor.memory": "8g",
"spark.executor.memoryOverhead": "2g",
"spark.memory.fraction": "0.6",
"spark.memory.storageFraction": "0.5",
# Parallelism
"spark.sql.shuffle.partitions": "200",
"spark.default.parallelism": "200",
# Serialization
"spark.serializer": "org.apache.spark.serializer.KryoSerializer",
"spark.sql.execution.arrow.pyspark.enabled": "true",
# Compression
"spark.io.compression.codec": "lz4",
"spark.shuffle.compress": "true",
# Broadcast
"spark.sql.autoBroadcastJoinThreshold": "50MB",
# File handling
"spark.sql.files.maxPartitionBytes": "128MB",
"spark.sql.files.openCostInBytes": "4MB",
}
.count() for existence - Use .take(1) or .isEmpty()Weekly Installs
3.3K
Repository
GitHub Stars
32.2K
First Seen
Jan 20, 2026
Security Audits
Gen Agent Trust HubPassSocketPassSnykPass
Installed on
claude-code2.5K
gemini-cli2.4K
opencode2.4K
cursor2.4K
codex2.3K
github-copilot2.1K
React 组合模式指南:Vercel 组件架构最佳实践,提升代码可维护性
102,200 周安装