PySpark高效数据切分: 按ID分组抽取时间序列片段
2025-01-16 20:21:46
PySpark 数据切分:按ID分组抽取时间序列数据片段
在处理大规模时间序列数据时,经常会遇到需要根据特定ID分组,并对每组数据按时间或步骤进行抽样的情况。本文探讨如何使用PySpark高效解决此类问题,核心需求为:对每个ID,保留特定步骤的数据(如前三个步骤)并从指定步骤(如第四步)中抽取时间靠前的部分数据,达到更精细化的数据切割。
问题分析
核心挑战在于:如何对 PySpark DataFrame 中不同 ID 的数据进行独立切片,并最终整合。简单的循环遍历是不被推荐的,因为它会将 Spark 的并行处理能力削弱,并造成效率低下。另外一个问题是,在每个ID的第四步数据中抽取指定时间段内的数据时,如何实现精准抽取?一种常用的方法是先计算出第四步数据的事件时间百分位数阈值,再进行过滤。我们关注如何在一个更全局的层面完成所有ID的抽取工作,而不是针对单个ID处理。
解决方案一:使用 Window 函数和过滤
该方法借助PySpark 的 Window 函数,计算每个ID中第四步的事件时间百分位数,之后利用 filter
函数进行数据分割。这种方式能够在一个Spark任务中并行处理所有ID,充分利用计算资源。
步骤如下:
- 定义 Window,基于ID分组。
- 过滤出每个ID中第四步的数据。
- 计算每个ID第四步数据的时间百分位数。
- 合并前三个步骤的数据以及过滤后的第四步数据。
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import col, percentile_approx, expr, when
spark = SparkSession.builder.appName("DataPartition").getOrCreate()
# 创建示例DataFrame,替换为你的真实数据
data = [
(1456942945, 'var_a', 123.4, 1, 'id_1'),
(1456931076, 'var_b', 857.01, 1, 'id_1'),
(1456932268, 'var_b', 871.74, 1, 'id_1'),
(1456940055, 'var_b', 992.3, 2, 'id_1'),
(1456932781, 'var_c', 861.3, 2, 'id_1'),
(1456937186, 'var_c', 959.6, 3, 'id_1'),
(1456934746, 'var_d', 0.12, 4, 'id_1'),
(1456944444, 'var_e', 0.14, 4, 'id_1'),
(1456946777, 'var_e', 0.23, 4, 'id_1'),
(1456942945, 'var_a', 123.4, 1, 'id_2'),
(1456931076, 'var_b', 847.01, 1, 'id_2'),
(1456932268, 'var_b', 871.74, 1, 'id_2'),
(1456940055, 'var_b', 932.3, 2, 'id_2'),
(1456932781, 'var_c', 821.3, 3, 'id_2'),
(1456937186, 'var_c', 969.6, 4, 'id_2'),
(1456939879, 'var_f', 96.2, 4, 'id_2'),
(1456934746, 'var_d', 0.12, 4, 'id_2')
]
schema = ["event_time", "variable", "value", "step", "ID"]
df = spark.createDataFrame(data, schema)
window_spec = Window.partitionBy("ID")
threshold_percentile = 0.25
# 计算每个ID的事件时间分位数
df_with_threshold = df.withColumn(
"threshold",
when(col("step") == 4,
percentile_approx("event_time", threshold_percentile).over(window_spec))
).otherwise(None)
# 使用计算出来的阈值过滤第4步数据并合并
filtered_df = df_with_threshold.filter(
((col("step") < 4) )| ((col("step") == 4) & (col("event_time") <= col("threshold")))
).drop("threshold")
filtered_df.show()
该代码首先通过 Window.partitionBy("ID")
定义了一个基于 ID 的窗口。然后,针对第四步数据,计算其 25% 分位数阈值。 最后,我们合并前三步和过滤后的第四步数据,得到所需的结果。这种方法直接利用了PySpark的窗口函数特性,性能高效。
解决方案二:使用 groupBy
和 flatMap
此方案使用了 groupBy
对 ID 分组,并配合 flatMap
进行精细数据操作。它对每个ID先进行单独处理,然后把结果统一成一个新的DataFrame。相比第一种方法,逻辑上更加清晰易懂。
操作步骤:
- 对 DataFrame 按照 ID 进行分组。
- 针对每个 ID,进行步骤数据筛选以及时间范围的数据抽取。
- 将所有 ID 的操作结果使用
flatMap
进行扁平化处理并得到新的 DataFrame。
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, percentile_approx, collect_list
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType
spark = SparkSession.builder.appName("DataPartition").getOrCreate()
# 创建示例 DataFrame
data = [
(1456942945, 'var_a', 123.4, 1, 'id_1'),
(1456931076, 'var_b', 857.01, 1, 'id_1'),
(1456932268, 'var_b', 871.74, 1, 'id_1'),
(1456940055, 'var_b', 992.3, 2, 'id_1'),
(1456932781, 'var_c', 861.3, 2, 'id_1'),
(1456937186, 'var_c', 959.6, 3, 'id_1'),
(1456934746, 'var_d', 0.12, 4, 'id_1'),
(1456944444, 'var_e', 0.14, 4, 'id_1'),
(1456946777, 'var_e', 0.23, 4, 'id_1'),
(1456942945, 'var_a', 123.4, 1, 'id_2'),
(1456931076, 'var_b', 847.01, 1, 'id_2'),
(1456932268, 'var_b', 871.74, 1, 'id_2'),
(1456940055, 'var_b', 932.3, 2, 'id_2'),
(1456932781, 'var_c', 821.3, 3, 'id_2'),
(1456937186, 'var_c', 969.6, 4, 'id_2'),
(1456939879, 'var_f', 96.2, 4, 'id_2'),
(1456934746, 'var_d', 0.12, 4, 'id_2')
]
schema = ["event_time", "variable", "value", "step", "ID"]
df = spark.createDataFrame(data, schema)
threshold_percentile = 0.25
def partition_data(rows):
id_value = rows[0]["ID"]
all_data = [row.asDict() for row in rows]
first_steps = [d for d in all_data if d['step'] in [1, 2, 3]]
step4_data = [d for d in all_data if d['step'] == 4]
if step4_data:
# 这里避免当没有 step=4 的时候 报错
event_times_step4 = [row["event_time"] for row in step4_data]
if len(event_times_step4) > 0 :
threshold = sorted(event_times_step4)[int(len(event_times_step4)* threshold_percentile) -1 ] # avoid error while no enough data
else:
threshold = 0
filtered_step4 = [d for d in step4_data if d["event_time"] <= threshold]
return first_steps + filtered_step4
else:
return first_steps
grouped_rdd = df.groupBy("ID").agg(collect_list(struct(*schema))).rdd.flatMap(lambda x : partition_data(x[1]) ) # must specify which schema to unpack here , but i'm using raw sql to create this schema
output_df = spark.createDataFrame(grouped_rdd)
output_df.show()
此方法将 DataFrame 按照 ID 分组,并应用 flatMap
函数对每一组执行数据处理,此过程更加直观,容易理解。它直接操作 DataFrame 的内部结构,实现了数据切割。 需要注意的是,为避免因输入数据的改变造成的Schema 不匹配的问题, 需要通过 schema
来确定 RDD 的schema信息,从而进行 createDataFrame()
操作。
安全建议
- 在实际应用中,数据切片和抽取操作应做好异常处理,应对数据质量问题和空值情况。
- 避免直接操作大表中的
id
列数据,可以通过数据采样或者其它方式预先确认,确保数据量可控。
通过以上方案,可以有效且高效地利用PySpark对DataFrame进行灵活切片。根据实际需求和数据特点,可以选择适合的方案来解决问题。