返回

Databricks Scala 高效并行调用 REST API (百万级数据)

python

好的,这是你要的博客文章:


Databricks+Scala 高效并行调用 REST API 获取百万级数据

拉取大量数据时,单个 API 请求经常遇到限制。比如,你想从某个 API 获取洛杉矶县一百万栋建筑的信息,但一次调用最多只能返回一万条记录。这种情况下,就得多次调用 API 来分页获取数据。

直接用 Python 的 asyncio 来做,像下面这样,虽然用了异步,但拉取一百万条数据可能还是慢得让人抓狂(比如十几分钟甚至更久):

(这里可以省略用户提供的 Python 代码,因为文章重点是 Scala 解决方案)

当同步或异步 Python 请求都扛不住时,自然会想到用更强力的工具。Databricks 配合 Scala,利用 Apache Spark 的分布式计算能力,就能把这些 API 调用分散到多个计算节点上并行执行,大大缩短数据获取时间。

这篇文章就来聊聊怎么用 Databricks 和 Scala 来解决这个问题。

为啥慢?瓶颈在哪?

原来的 Python 脚本(即使是异步的)之所以慢,主要原因有几个:

  1. 串行或有限并发 :虽然 asyncio 提供了并发能力,但它仍然在单个 Python 进程内运行。网络 I/O 是并发了,可 CPU 密集型的操作(比如 JSON 解析)和 GIL(全局解释器锁)在某些情况下还是会限制整体吞吐量。并发度也受限于你设置的 Semaphore 和机器资源。
  2. 网络延迟累积 :一百万数据,假设每批 1000 条(如示例),就需要 1000 次 API 调用。即使每次调用只花 100 毫秒,累积起来也是 100 秒的网络时间,这还不算服务器处理时间和数据传输时间。如果并发度不高,总时间会被最慢的一批请求拖累。
  3. 单点处理压力 :所有请求的发起、管理和数据汇总都在一个地方(你的本地机器或单个驱动节点),容易成为瓶颈。

Databricks + Scala 并行方案

思路很简单:把需要发起的 API 调用(根据分页逻辑确定)变成一个个独立的任务,然后让 Spark 集群里的各个 Executor (工作节点)去并行处理这些任务。最后把结果汇总起来。

分布式 API 调用示意图 (Conceptual)
(提示:这里可以用一个简单的图示说明 Driver 如何分发任务给 Executors,每个 Executor 独立调用 API)

第一步:搞清楚要调多少次 API

和 Python 脚本里一样,咱们首先得知道总共有多少条符合条件的记录,才能算出需要分页调用多少次。大部分支持分页的 API 会提供一个只返回总数的方法,或者在首次请求时就返回总数。

假设咱们的目标 API (洛杉矶建筑数据)允许通过特定参数(比如 returnCountOnly=true)获取总数。

import requests // 方便起见,用 requests-scala 库,简洁易用
import ujson // 用于解析 JSON

// API 地址和基础查询参数
val baseUrl = "https://services.arcgis.com/RmCCgQtiZLDCtblq/arcgis/rest/services/Countywide_Building_Outlines/FeatureServer/1/query"
val baseQueryParams = Map(
    "where" -> "(HEIGHT < 33) AND UseType = 'RESIDENTIAL' AND SitusCity IN('LOS ANGELES CA','BEVERLY HILLS CA',  'PALMDALE')",
    "outFields" -> "*",
    "outSR" -> "4326",
    "f" -> "json"
)

// 获取总记录数
def getTotalRecordCount(): Int = {
    val params = baseQueryParams ++ Map("returnCountOnly" -> "true")
    try {
        val response = requests.get(baseUrl, params = params)
        if (response.statusCode == 200) {
            val json = ujson.read(response.text)
            json("count").num.toInt // 注意:这里的 key 可能需要根据实际 API 返回调整
        } else {
            println(s"获取总数失败,状态码: ${response.statusCode}, 消息: ${response.text}")
            0 // 或者抛出异常
        }
    } catch {
        case e: Exception =>
            println(s"请求总数时发生异常: ${e.getMessage}")
            0 // 或者抛出异常
    }
}

val totalCount = getTotalRecordCount()
println(s"总记录数: $totalCount")

// 确定分页参数
val batchSize = 10000 // ESRI API 限制通常更高,可以设为 10000
val offsets = (0 until totalCount by batchSize).toList // 生成所有需要的 offset 值
println(s"需要发起 ${offsets.size} 次 API 调用")

注意 :

  • 上面的 requests-scalaujson 需要在 Databricks 集群上安装。可以在 Notebook 里使用 %pip install requests-scala ujson (如果用 Python kernel 混合) 或者在集群库设置里添加 Maven 坐标 com.lihaoyi::requests:0.8.0com.lihaoyi::ujson:2.0.0 (版本号请根据需要选择)。
  • 错误处理要做好,网络请求总可能失败。

第二步:并行发起 API 请求

拿到所有需要请求的 offset 值列表后,就可以用 Spark 把它们分发出去并行处理了。

import spark.implicits._ // 导入隐式转换为 Dataset/DataFrame
import org.apache.spark.sql.{DataFrame, SparkSession}
import scala.util.{Try, Success, Failure}

// 把 offsets 列表变成 Spark 的 Dataset。让每个 offset 成为一个任务单元
val offsetsDS = offsets.toDS()

// 定义一个 case class 来匹配返回的建筑数据结构,方便后面转 DataFrame
// 注意:字段名和类型要跟 API 返回的 feature 的 attributes 匹配
case class BuildingFeature(
    OBJECTID: Option[Long], // 字段名严格按照 API 返回来
    JURISDICTN: Option[String],
    AIN: Option[String],
    UseType: Option[String],
    HEIGHT: Option[Double],
    SitusCity: Option[String]
    // ... 其他你需要的字段
    // 使用 Option 类型可以更好地处理 null 或缺失值
)

// 这个函数将在每个 Executor 上执行,处理一部分 offset
def fetchBatchData(offset: Int): Seq[BuildingFeature] = {
    val queryParams = baseQueryParams ++ Map(
        "resultOffset" -> offset.toString,
        "resultRecordCount" -> batchSize.toString // 确保请求正确的批次大小
    )
    
    println(s"正在处理 offset: $offset") // 可以在 Executor 日志中看到进度

    Try {
        // 重要:HttpClient 实例建议在 mapPartitions 内部创建和管理,避免序列化问题和资源浪费
        // 这里用 requests-scala 举例
        val response = requests.get(baseUrl, params = queryParams, readTimeout = 60000, connectTimeout = 10000) // 设置超时

        if (response.statusCode == 200) {
            val json = ujson.read(response.text)
            // 解析 features 数组
            json.obj.get("features") match { // 使用 .obj.get 安全访问 key
                case Some(featuresArray) =>
                    featuresArray.arr.map { featureJson =>
                        val attributes = featureJson("attributes")
                        // 把 JSON 对象映射到 Case Class
                        BuildingFeature(
                            attributes.obj.get("OBJECTID").map(_.num.toLong),
                            attributes.obj.get("JURISDICTN").map(_.str),
                            attributes.obj.get("AIN").map(_.str),
                            attributes.obj.get("UseType").map(_.str),
                            attributes.obj.get("HEIGHT").map(_.num),
                            attributes.obj.get("SitusCity").map(_.str)
                            // ... 映射其他字段
                        )
                    }.toSeq // 确保返回 Seq[BuildingFeature]
                case None =>
                    println(s"Offset $offset 返回的数据中未找到 'features' 字段。JSON: ${response.text.take(500)}") // 打印部分 JSON 帮助调试
                    Seq.empty[BuildingFeature] // 返回空序列
            }
        } else {
            println(s"请求 offset $offset 失败,状态码: ${response.statusCode}, 消息: ${response.text.take(500)}")
            Seq.empty[BuildingFeature] // 失败则返回空序列
        }
    } match {
        case Success(features) => features
        case Failure(e) =>
            println(s"处理 offset $offset 时发生异常: ${e.getMessage}")
            // 考虑添加重试逻辑,或者标记失败的 offset
            Seq.empty[BuildingFeature] // 异常也返回空序列
    }
}

// 使用 flatMap,因为每个 offset 请求返回的是一个特征 *列表* (Seq[BuildingFeature])
// flatMap 会把所有 Executor 返回的列表“压平”成一个大的 RDD/Dataset
// cache() / persist() 可选,如果后续还要多次操作这个结果集,缓存能提速
val featuresDS = offsetsDS.flatMap(offset => fetchBatchData(offset)).cache()

// 触发计算并统计结果
val retrievedCount = featuresDS.count()
println(s"成功获取并解析了 $retrievedCount 条记录")

关键点解释 :

  • offsets.toDS() : 把 Scala 的 List[Int] 转换成 Spark 的 Dataset[Int]。这是分布式处理的起点。
  • case class BuildingFeature : 定义数据结构,让 JSON 解析和后续转 DataFrame 更方便、类型更安全。字段名要跟 API 返回的 JSON attributes 里的键完全一致。用 Option[T] 类型处理可能缺失或为 null 的字段是个好习惯。
  • fetchBatchData(offset: Int) : 这是核心函数,它接收一个 offset 值,发起 API 请求,解析 JSON,然后返回一个 Seq[BuildingFeature]这个函数会在 Spark 的 Executor 节点上并行执行
  • Try/Success/Failure : 用于捕获 HTTP 请求和 JSON 解析中可能发生的异常,增加程序的健壮性。
  • requests.get(...) : 在 fetchBatchData 内部发起 HTTP GET 请求。记得设置合理的 readTimeoutconnectTimeout,避免长时间卡死。
  • ujson.read() 和解析逻辑 : 使用 ujson 解析返回的 JSON 字符串。注意要准确地从 JSON 结构中提取 features 数组,并把每个 featureattributes 映射到 BuildingFeature case class。这里的 .obj.get("key").map(...) 模式比直接 ("key") 更安全,能处理 key 不存在的情况。
  • offsetsDS.flatMap(...) : 这是执行并行化的关键。flatMap 会把 offsetsDS 中的每个 offset 值传递给 fetchBatchData 函数。由于 fetchBatchData 返回 Seq[BuildingFeature]flatMap 会把所有 Executor 返回的 Seq 合并(“压平”)成一个单一的 Dataset[BuildingFeature]
  • .cache() : 把 featuresDS 缓存到内存(或磁盘)。如果后续需要对这个 Dataset 进行多次操作(比如统计、筛选、保存),缓存能避免重复执行前面的 API 调用和解析步骤。如果只用一次(比如直接保存),可以不缓存。
  • featuresDS.count() : 这是一个 Action 操作,会触发上面所有 flatMap 的实际计算。计算完成后,它返回获取到的总记录数。

第三步:转换成 DataFrame 并使用

拿到了 Dataset[BuildingFeature] 之后,通常需要把它转换成 DataFrame 进行后续分析或保存。

// Dataset[CaseClass] 可以非常方便地直接转成 DataFrame
val buildingsDF = featuresDS.toDF()

// 显示 DataFrame 结构和一些数据
buildingsDF.printSchema()
buildingsDF.show(10, truncate = false) // 显示前10条,不截断列内容

// 后续可以进行各种 Spark SQL 操作或 DataFrame 操作
println(s"DataFrame 中的总记录数: ${buildingsDF.count()}") // 这次 count 会很快,因为 RDD/Dataset 被缓存了

// 示例:按 UseType 统计建筑数量
buildingsDF.groupBy("UseType").count().orderBy($"count".desc).show()

// 示例:保存为 Parquet 文件(推荐,高效、带 Schema)
// buildingsDF.write.mode("overwrite").parquet("/path/to/save/la_buildings.parquet") 
// 注意替换成你在 DBFS 或挂载存储上的实际路径

性能和进阶技巧

  1. 调整并发度 :

    • offsetsDSflatMap 前可以通过 repartition(N) 来控制初始分区数 NN 应该设置得比集群核心数多一些(比如 2-4 倍),以保证所有核心都能充分利用。val offsetsDS = offsets.toDS().repartition(spark.sparkContext.defaultParallelism * 3)
    • 过高的并发可能会给 API 服务器造成太大压力,导致请求被拒绝或限流。需要找到一个平衡点。
  2. 优化 HTTP Client :

    • requests-scala 简单易用,但对于超高性能场景,可能需要考虑更底层的库(如 sttp 配合后端,如 Akka HTTPNetty),或者在 mapPartitions 级别管理连接池(稍微复杂些)。在 mapPartitions 内部创建客户端实例通常是推荐做法,避免序列化和每个任务都创建连接的开销。
    // 示例:使用 mapPartitions 管理客户端实例(伪代码)
    offsetsDS.rdd.mapPartitions { partitionIterator =>
        // 每个分区(在同一个 Executor 上运行的一组任务)创建一个 Client 实例
        // val httpClient = createHttpClient() // 初始化 HTTP Client
        val results = partitionIterator.flatMap { offset =>
            // 使用同一个 httpClient 发起多次请求
            fetchBatchDataWithClient(httpClient, offset) // 修改 fetch 函数接收 Client
        }
        // 可选:关闭/清理 httpClient
        // closeHttpClient(httpClient)
        results // 返回这个分区的处理结果
    }.toDS() 
    
  3. 健壮的错误处理与重试 :

    • 简单的 Try/Catch 可能不够。对于可能瞬时失败的网络请求,实现带退避策略(Exponential Backoff)的自动重试机制会更可靠。可以用简单的 Thread.sleep 实现基本重试,或者引入 retry 库。
    • 记录失败的 offset,方便后续单独处理或排查问题。
  4. API 速率限制 :

    • 如果 API 有严格的速率限制(Rate Limiting),即使并行化,也可能很快达到上限。你可能需要在 fetchBatchData 函数里加入适当的延时(Thread.sleep(milliseconds)),或者使用更复杂的令牌桶/漏桶算法来控制请求速率。一定要尊重 API 提供方的使用条款和限制!
  5. 处理超大数据量 (远超内存) :

    • 如果一百万条记录(或解析后的数据)对于单个 Executor 的内存来说仍然太大,flatMap 后直接 .toDF() 可能导致 OOM (Out Of Memory)。这时可以考虑:
      • fetchBatchData 里直接返回轻量级的数据结构(比如只包含必要字段的 StringTuple),减少内存占用。
      • 不使用 .cache(),或者使用 StorageLevel.MEMORY_AND_DISK
      • 处理完一批数据后直接写入中间存储(如 Parquet 文件),最后再合并。

安全建议

  • API 密钥/令牌管理 : 如果 API 需要认证(比如 API Key 或 Bearer Token),绝对不要 硬编码在代码里。应该使用 Databricks 的 Secrets 功能来安全地存储和访问这些凭证。
    // 读取 Databricks Secret
    // val apiKey = dbutils.secrets.get(scope = "your-secret-scope", key = "api-key")
    // 然后在请求头或参数中带上 apiKey
    // val response = requests.get(baseUrl, params = queryParams, headers = Map("Authorization" -> s"Bearer $apiKey"))
    

通过这种 Spark 并行化改造,相比于单机 Python 脚本,你应该能看到数据获取速度的巨大提升,将原本几十分钟甚至数小时的任务缩短到几分钟以内,具体取决于你的 Databricks 集群规模和网络条件以及 API 服务器的响应能力。