返回

PySpark高效数据切分: 按ID分组抽取时间序列片段

python

PySpark 数据切分:按ID分组抽取时间序列数据片段

在处理大规模时间序列数据时,经常会遇到需要根据特定ID分组,并对每组数据按时间或步骤进行抽样的情况。本文探讨如何使用PySpark高效解决此类问题,核心需求为:对每个ID,保留特定步骤的数据(如前三个步骤)并从指定步骤(如第四步)中抽取时间靠前的部分数据,达到更精细化的数据切割。

问题分析

核心挑战在于:如何对 PySpark DataFrame 中不同 ID 的数据进行独立切片,并最终整合。简单的循环遍历是不被推荐的,因为它会将 Spark 的并行处理能力削弱,并造成效率低下。另外一个问题是,在每个ID的第四步数据中抽取指定时间段内的数据时,如何实现精准抽取?一种常用的方法是先计算出第四步数据的事件时间百分位数阈值,再进行过滤。我们关注如何在一个更全局的层面完成所有ID的抽取工作,而不是针对单个ID处理。

解决方案一:使用 Window 函数和过滤

该方法借助PySpark 的 Window 函数,计算每个ID中第四步的事件时间百分位数,之后利用 filter 函数进行数据分割。这种方式能够在一个Spark任务中并行处理所有ID,充分利用计算资源。

步骤如下:

  1. 定义 Window,基于ID分组。
  2. 过滤出每个ID中第四步的数据。
  3. 计算每个ID第四步数据的时间百分位数。
  4. 合并前三个步骤的数据以及过滤后的第四步数据。
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import col, percentile_approx, expr, when

spark = SparkSession.builder.appName("DataPartition").getOrCreate()

# 创建示例DataFrame,替换为你的真实数据
data = [
    (1456942945, 'var_a', 123.4, 1, 'id_1'),
    (1456931076, 'var_b', 857.01, 1, 'id_1'),
    (1456932268, 'var_b', 871.74, 1, 'id_1'),
    (1456940055, 'var_b', 992.3, 2, 'id_1'),
    (1456932781, 'var_c', 861.3, 2, 'id_1'),
    (1456937186, 'var_c', 959.6, 3, 'id_1'),
    (1456934746, 'var_d', 0.12, 4, 'id_1'),
     (1456944444, 'var_e', 0.14, 4, 'id_1'),
      (1456946777, 'var_e', 0.23, 4, 'id_1'),
     (1456942945, 'var_a', 123.4, 1, 'id_2'),
    (1456931076, 'var_b', 847.01, 1, 'id_2'),
    (1456932268, 'var_b', 871.74, 1, 'id_2'),
    (1456940055, 'var_b', 932.3, 2, 'id_2'),
    (1456932781, 'var_c', 821.3, 3, 'id_2'),
    (1456937186, 'var_c', 969.6, 4, 'id_2'),
     (1456939879, 'var_f', 96.2, 4, 'id_2'),
    (1456934746, 'var_d', 0.12, 4, 'id_2')
]
schema = ["event_time", "variable", "value", "step", "ID"]
df = spark.createDataFrame(data, schema)

window_spec = Window.partitionBy("ID")
threshold_percentile = 0.25

# 计算每个ID的事件时间分位数
df_with_threshold = df.withColumn(
    "threshold",
    when(col("step") == 4,
         percentile_approx("event_time", threshold_percentile).over(window_spec))
         ).otherwise(None)


# 使用计算出来的阈值过滤第4步数据并合并
filtered_df = df_with_threshold.filter(
   ((col("step") < 4) )| ((col("step") == 4) & (col("event_time") <= col("threshold")))

).drop("threshold")

filtered_df.show()

该代码首先通过 Window.partitionBy("ID") 定义了一个基于 ID 的窗口。然后,针对第四步数据,计算其 25% 分位数阈值。 最后,我们合并前三步和过滤后的第四步数据,得到所需的结果。这种方法直接利用了PySpark的窗口函数特性,性能高效。

解决方案二:使用 groupByflatMap

此方案使用了 groupBy 对 ID 分组,并配合 flatMap 进行精细数据操作。它对每个ID先进行单独处理,然后把结果统一成一个新的DataFrame。相比第一种方法,逻辑上更加清晰易懂。

操作步骤:

  1. 对 DataFrame 按照 ID 进行分组。
  2. 针对每个 ID,进行步骤数据筛选以及时间范围的数据抽取。
  3. 将所有 ID 的操作结果使用 flatMap 进行扁平化处理并得到新的 DataFrame。
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, percentile_approx, collect_list
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType

spark = SparkSession.builder.appName("DataPartition").getOrCreate()


# 创建示例 DataFrame
data = [
    (1456942945, 'var_a', 123.4, 1, 'id_1'),
    (1456931076, 'var_b', 857.01, 1, 'id_1'),
    (1456932268, 'var_b', 871.74, 1, 'id_1'),
    (1456940055, 'var_b', 992.3, 2, 'id_1'),
    (1456932781, 'var_c', 861.3, 2, 'id_1'),
    (1456937186, 'var_c', 959.6, 3, 'id_1'),
    (1456934746, 'var_d', 0.12, 4, 'id_1'),
     (1456944444, 'var_e', 0.14, 4, 'id_1'),
      (1456946777, 'var_e', 0.23, 4, 'id_1'),
     (1456942945, 'var_a', 123.4, 1, 'id_2'),
    (1456931076, 'var_b', 847.01, 1, 'id_2'),
    (1456932268, 'var_b', 871.74, 1, 'id_2'),
    (1456940055, 'var_b', 932.3, 2, 'id_2'),
    (1456932781, 'var_c', 821.3, 3, 'id_2'),
    (1456937186, 'var_c', 969.6, 4, 'id_2'),
     (1456939879, 'var_f', 96.2, 4, 'id_2'),
    (1456934746, 'var_d', 0.12, 4, 'id_2')
]

schema = ["event_time", "variable", "value", "step", "ID"]
df = spark.createDataFrame(data, schema)
threshold_percentile = 0.25

def partition_data(rows):
    id_value = rows[0]["ID"]
    all_data = [row.asDict() for row in rows]

    first_steps = [d for d in all_data if d['step'] in [1, 2, 3]]
    step4_data = [d for d in all_data if d['step'] == 4]

    if step4_data:
         # 这里避免当没有 step=4 的时候 报错
         
        event_times_step4 = [row["event_time"] for row in step4_data]

        if len(event_times_step4) > 0 :  
                threshold = sorted(event_times_step4)[int(len(event_times_step4)* threshold_percentile) -1 ] # avoid error while no enough data
        else:
                 threshold = 0
                 
    
        filtered_step4 = [d for d in step4_data if d["event_time"] <= threshold]

        return first_steps + filtered_step4
    else:
         return first_steps
 


grouped_rdd = df.groupBy("ID").agg(collect_list(struct(*schema))).rdd.flatMap(lambda x : partition_data(x[1]) ) # must specify which schema to unpack here , but i'm using raw sql to create this schema
output_df =  spark.createDataFrame(grouped_rdd)
output_df.show()

此方法将 DataFrame 按照 ID 分组,并应用 flatMap 函数对每一组执行数据处理,此过程更加直观,容易理解。它直接操作 DataFrame 的内部结构,实现了数据切割。 需要注意的是,为避免因输入数据的改变造成的Schema 不匹配的问题, 需要通过 schema 来确定 RDD 的schema信息,从而进行 createDataFrame() 操作。

安全建议

  • 在实际应用中,数据切片和抽取操作应做好异常处理,应对数据质量问题和空值情况。
  • 避免直接操作大表中的 id 列数据,可以通过数据采样或者其它方式预先确认,确保数据量可控。

通过以上方案,可以有效且高效地利用PySpark对DataFrame进行灵活切片。根据实际需求和数据特点,可以选择适合的方案来解决问题。