使用 pyspark 处理数据的基本流程

总流程

读数据

读数据前,要考虑两个方面:数据的外观(对外呈现给我们的样子)和内在(具体存放的位置)。具体来说,要考虑:

  • 表头
  • schema
  • 存储格式(尽量用 parquet,列式存储,压缩性能,读取快)
  • 数据量
df = spark.read.csv("s3://mybucket/large_dataset.csv", header=True, inferSchema=True)
df = spark.read.parquet("s3://mybucket/large_dataset/")

读完数据后,要查看(探索)数据

print(f"Total Rows: {df.count()}")
df.printSchema()
df.show(5)

数据清洗

这个阶段就涉对数据内容的操作了,可以通过函数算子的方式(记不住的话可以通过官网查看),也可以通过 sql 的方式。

from pyspark.sql.functions import col, when, avg
# Drop null values
df = df.na.drop()
# Filter out invalid entries
df = df.filter(col("age") > 0)
# Add a calculated column
df = df.withColumn("income_per_person", col("total_income") / col("family_size"))

# Replace missing values with mean
mean_val = df.select(avg("salary")).collect()[0][0]
df = df.na.fill({"salary": mean_val})

数据计算

这个阶段也是对数据内容的操作,只是数据的内容和格式已经比较规范。这个阶段要注意对资源的利用,时刻通过 sparkUI 关注计算效率。尽量把 task 分配到不同 exector 上,最大化利用集群的力量。

region_stats = df.groupBy("region") \
    .agg(avg("income_per_person").alias("avg_income")) \
    .orderBy(col("avg_income").desc())
region_stats.show(10)


这一阶段可能会遇到数据倾斜的问题。
对于小表
from pyspark.sql.functions import broadcast
# Broadcast join for small tables
df = large_df.join(broadcast(small_df), "customer_id")

对于大表
df1 = df1.repartition("customer_id")
df2 = df2.repartition("customer_id")
joined_df = df1.join(df2, "customer_id")

数据保存

一般保存成表或者文件

df.write.mode("overwrite").parquet("s3://mybucket/processed_data/")
df.write.format("jdbc").option("url", "jdbc:mysql://...").save()

数据可视化

import matplotlib.pyplot as plt
sample = df.limit(10000).toPandas()
sample['income_per_person'].hist(bins=50)
plt.title("Income Distribution")
plt.show()