Databricks Scala 高效并行调用 REST API (百万级数据)
2025-04-01 18:02:47
好的,这是你要的博客文章:
Databricks+Scala 高效并行调用 REST API 获取百万级数据
拉取大量数据时,单个 API 请求经常遇到限制。比如,你想从某个 API 获取洛杉矶县一百万栋建筑的信息,但一次调用最多只能返回一万条记录。这种情况下,就得多次调用 API 来分页获取数据。
直接用 Python 的 asyncio
来做,像下面这样,虽然用了异步,但拉取一百万条数据可能还是慢得让人抓狂(比如十几分钟甚至更久):
(这里可以省略用户提供的 Python 代码,因为文章重点是 Scala 解决方案)
当同步或异步 Python 请求都扛不住时,自然会想到用更强力的工具。Databricks 配合 Scala,利用 Apache Spark 的分布式计算能力,就能把这些 API 调用分散到多个计算节点上并行执行,大大缩短数据获取时间。
这篇文章就来聊聊怎么用 Databricks 和 Scala 来解决这个问题。
为啥慢?瓶颈在哪?
原来的 Python 脚本(即使是异步的)之所以慢,主要原因有几个:
- 串行或有限并发 :虽然
asyncio
提供了并发能力,但它仍然在单个 Python 进程内运行。网络 I/O 是并发了,可 CPU 密集型的操作(比如 JSON 解析)和 GIL(全局解释器锁)在某些情况下还是会限制整体吞吐量。并发度也受限于你设置的Semaphore
和机器资源。 - 网络延迟累积 :一百万数据,假设每批 1000 条(如示例),就需要 1000 次 API 调用。即使每次调用只花 100 毫秒,累积起来也是 100 秒的网络时间,这还不算服务器处理时间和数据传输时间。如果并发度不高,总时间会被最慢的一批请求拖累。
- 单点处理压力 :所有请求的发起、管理和数据汇总都在一个地方(你的本地机器或单个驱动节点),容易成为瓶颈。
Databricks + Scala 并行方案
思路很简单:把需要发起的 API 调用(根据分页逻辑确定)变成一个个独立的任务,然后让 Spark 集群里的各个 Executor
(工作节点)去并行处理这些任务。最后把结果汇总起来。
(提示:这里可以用一个简单的图示说明 Driver 如何分发任务给 Executors,每个 Executor 独立调用 API)
第一步:搞清楚要调多少次 API
和 Python 脚本里一样,咱们首先得知道总共有多少条符合条件的记录,才能算出需要分页调用多少次。大部分支持分页的 API 会提供一个只返回总数的方法,或者在首次请求时就返回总数。
假设咱们的目标 API (洛杉矶建筑数据)允许通过特定参数(比如 returnCountOnly=true
)获取总数。
import requests // 方便起见,用 requests-scala 库,简洁易用
import ujson // 用于解析 JSON
// API 地址和基础查询参数
val baseUrl = "https://services.arcgis.com/RmCCgQtiZLDCtblq/arcgis/rest/services/Countywide_Building_Outlines/FeatureServer/1/query"
val baseQueryParams = Map(
"where" -> "(HEIGHT < 33) AND UseType = 'RESIDENTIAL' AND SitusCity IN('LOS ANGELES CA','BEVERLY HILLS CA', 'PALMDALE')",
"outFields" -> "*",
"outSR" -> "4326",
"f" -> "json"
)
// 获取总记录数
def getTotalRecordCount(): Int = {
val params = baseQueryParams ++ Map("returnCountOnly" -> "true")
try {
val response = requests.get(baseUrl, params = params)
if (response.statusCode == 200) {
val json = ujson.read(response.text)
json("count").num.toInt // 注意:这里的 key 可能需要根据实际 API 返回调整
} else {
println(s"获取总数失败,状态码: ${response.statusCode}, 消息: ${response.text}")
0 // 或者抛出异常
}
} catch {
case e: Exception =>
println(s"请求总数时发生异常: ${e.getMessage}")
0 // 或者抛出异常
}
}
val totalCount = getTotalRecordCount()
println(s"总记录数: $totalCount")
// 确定分页参数
val batchSize = 10000 // ESRI API 限制通常更高,可以设为 10000
val offsets = (0 until totalCount by batchSize).toList // 生成所有需要的 offset 值
println(s"需要发起 ${offsets.size} 次 API 调用")
注意 :
- 上面的
requests-scala
和ujson
需要在 Databricks 集群上安装。可以在 Notebook 里使用%pip install requests-scala ujson
(如果用 Python kernel 混合) 或者在集群库设置里添加 Maven 坐标com.lihaoyi::requests:0.8.0
和com.lihaoyi::ujson:2.0.0
(版本号请根据需要选择)。 - 错误处理要做好,网络请求总可能失败。
第二步:并行发起 API 请求
拿到所有需要请求的 offset
值列表后,就可以用 Spark 把它们分发出去并行处理了。
import spark.implicits._ // 导入隐式转换为 Dataset/DataFrame
import org.apache.spark.sql.{DataFrame, SparkSession}
import scala.util.{Try, Success, Failure}
// 把 offsets 列表变成 Spark 的 Dataset。让每个 offset 成为一个任务单元
val offsetsDS = offsets.toDS()
// 定义一个 case class 来匹配返回的建筑数据结构,方便后面转 DataFrame
// 注意:字段名和类型要跟 API 返回的 feature 的 attributes 匹配
case class BuildingFeature(
OBJECTID: Option[Long], // 字段名严格按照 API 返回来
JURISDICTN: Option[String],
AIN: Option[String],
UseType: Option[String],
HEIGHT: Option[Double],
SitusCity: Option[String]
// ... 其他你需要的字段
// 使用 Option 类型可以更好地处理 null 或缺失值
)
// 这个函数将在每个 Executor 上执行,处理一部分 offset
def fetchBatchData(offset: Int): Seq[BuildingFeature] = {
val queryParams = baseQueryParams ++ Map(
"resultOffset" -> offset.toString,
"resultRecordCount" -> batchSize.toString // 确保请求正确的批次大小
)
println(s"正在处理 offset: $offset") // 可以在 Executor 日志中看到进度
Try {
// 重要:HttpClient 实例建议在 mapPartitions 内部创建和管理,避免序列化问题和资源浪费
// 这里用 requests-scala 举例
val response = requests.get(baseUrl, params = queryParams, readTimeout = 60000, connectTimeout = 10000) // 设置超时
if (response.statusCode == 200) {
val json = ujson.read(response.text)
// 解析 features 数组
json.obj.get("features") match { // 使用 .obj.get 安全访问 key
case Some(featuresArray) =>
featuresArray.arr.map { featureJson =>
val attributes = featureJson("attributes")
// 把 JSON 对象映射到 Case Class
BuildingFeature(
attributes.obj.get("OBJECTID").map(_.num.toLong),
attributes.obj.get("JURISDICTN").map(_.str),
attributes.obj.get("AIN").map(_.str),
attributes.obj.get("UseType").map(_.str),
attributes.obj.get("HEIGHT").map(_.num),
attributes.obj.get("SitusCity").map(_.str)
// ... 映射其他字段
)
}.toSeq // 确保返回 Seq[BuildingFeature]
case None =>
println(s"Offset $offset 返回的数据中未找到 'features' 字段。JSON: ${response.text.take(500)}") // 打印部分 JSON 帮助调试
Seq.empty[BuildingFeature] // 返回空序列
}
} else {
println(s"请求 offset $offset 失败,状态码: ${response.statusCode}, 消息: ${response.text.take(500)}")
Seq.empty[BuildingFeature] // 失败则返回空序列
}
} match {
case Success(features) => features
case Failure(e) =>
println(s"处理 offset $offset 时发生异常: ${e.getMessage}")
// 考虑添加重试逻辑,或者标记失败的 offset
Seq.empty[BuildingFeature] // 异常也返回空序列
}
}
// 使用 flatMap,因为每个 offset 请求返回的是一个特征 *列表* (Seq[BuildingFeature])
// flatMap 会把所有 Executor 返回的列表“压平”成一个大的 RDD/Dataset
// cache() / persist() 可选,如果后续还要多次操作这个结果集,缓存能提速
val featuresDS = offsetsDS.flatMap(offset => fetchBatchData(offset)).cache()
// 触发计算并统计结果
val retrievedCount = featuresDS.count()
println(s"成功获取并解析了 $retrievedCount 条记录")
关键点解释 :
offsets.toDS()
: 把 Scala 的List[Int]
转换成 Spark 的Dataset[Int]
。这是分布式处理的起点。case class BuildingFeature
: 定义数据结构,让 JSON 解析和后续转 DataFrame 更方便、类型更安全。字段名要跟 API 返回的 JSONattributes
里的键完全一致。用Option[T]
类型处理可能缺失或为null
的字段是个好习惯。fetchBatchData(offset: Int)
: 这是核心函数,它接收一个offset
值,发起 API 请求,解析 JSON,然后返回一个Seq[BuildingFeature]
。这个函数会在 Spark 的 Executor 节点上并行执行 。Try/Success/Failure
: 用于捕获 HTTP 请求和 JSON 解析中可能发生的异常,增加程序的健壮性。requests.get(...)
: 在fetchBatchData
内部发起 HTTP GET 请求。记得设置合理的readTimeout
和connectTimeout
,避免长时间卡死。ujson.read()
和解析逻辑 : 使用ujson
解析返回的 JSON 字符串。注意要准确地从 JSON 结构中提取features
数组,并把每个feature
的attributes
映射到BuildingFeature
case class。这里的.obj.get("key").map(...)
模式比直接("key")
更安全,能处理 key 不存在的情况。offsetsDS.flatMap(...)
: 这是执行并行化的关键。flatMap
会把offsetsDS
中的每个offset
值传递给fetchBatchData
函数。由于fetchBatchData
返回Seq[BuildingFeature]
,flatMap
会把所有 Executor 返回的Seq
合并(“压平”)成一个单一的Dataset[BuildingFeature]
。.cache()
: 把featuresDS
缓存到内存(或磁盘)。如果后续需要对这个Dataset
进行多次操作(比如统计、筛选、保存),缓存能避免重复执行前面的 API 调用和解析步骤。如果只用一次(比如直接保存),可以不缓存。featuresDS.count()
: 这是一个 Action 操作,会触发上面所有flatMap
的实际计算。计算完成后,它返回获取到的总记录数。
第三步:转换成 DataFrame 并使用
拿到了 Dataset[BuildingFeature]
之后,通常需要把它转换成 DataFrame 进行后续分析或保存。
// Dataset[CaseClass] 可以非常方便地直接转成 DataFrame
val buildingsDF = featuresDS.toDF()
// 显示 DataFrame 结构和一些数据
buildingsDF.printSchema()
buildingsDF.show(10, truncate = false) // 显示前10条,不截断列内容
// 后续可以进行各种 Spark SQL 操作或 DataFrame 操作
println(s"DataFrame 中的总记录数: ${buildingsDF.count()}") // 这次 count 会很快,因为 RDD/Dataset 被缓存了
// 示例:按 UseType 统计建筑数量
buildingsDF.groupBy("UseType").count().orderBy($"count".desc).show()
// 示例:保存为 Parquet 文件(推荐,高效、带 Schema)
// buildingsDF.write.mode("overwrite").parquet("/path/to/save/la_buildings.parquet")
// 注意替换成你在 DBFS 或挂载存储上的实际路径
性能和进阶技巧
-
调整并发度 :
offsetsDS
在flatMap
前可以通过repartition(N)
来控制初始分区数N
。N
应该设置得比集群核心数多一些(比如 2-4 倍),以保证所有核心都能充分利用。val offsetsDS = offsets.toDS().repartition(spark.sparkContext.defaultParallelism * 3)
。- 过高的并发可能会给 API 服务器造成太大压力,导致请求被拒绝或限流。需要找到一个平衡点。
-
优化 HTTP Client :
requests-scala
简单易用,但对于超高性能场景,可能需要考虑更底层的库(如sttp
配合后端,如Akka HTTP
或Netty
),或者在mapPartitions
级别管理连接池(稍微复杂些)。在mapPartitions
内部创建客户端实例通常是推荐做法,避免序列化和每个任务都创建连接的开销。
// 示例:使用 mapPartitions 管理客户端实例(伪代码) offsetsDS.rdd.mapPartitions { partitionIterator => // 每个分区(在同一个 Executor 上运行的一组任务)创建一个 Client 实例 // val httpClient = createHttpClient() // 初始化 HTTP Client val results = partitionIterator.flatMap { offset => // 使用同一个 httpClient 发起多次请求 fetchBatchDataWithClient(httpClient, offset) // 修改 fetch 函数接收 Client } // 可选:关闭/清理 httpClient // closeHttpClient(httpClient) results // 返回这个分区的处理结果 }.toDS()
-
健壮的错误处理与重试 :
- 简单的
Try/Catch
可能不够。对于可能瞬时失败的网络请求,实现带退避策略(Exponential Backoff)的自动重试机制会更可靠。可以用简单的Thread.sleep
实现基本重试,或者引入retry
库。 - 记录失败的
offset
,方便后续单独处理或排查问题。
- 简单的
-
API 速率限制 :
- 如果 API 有严格的速率限制(Rate Limiting),即使并行化,也可能很快达到上限。你可能需要在
fetchBatchData
函数里加入适当的延时(Thread.sleep(milliseconds)
),或者使用更复杂的令牌桶/漏桶算法来控制请求速率。一定要尊重 API 提供方的使用条款和限制!
- 如果 API 有严格的速率限制(Rate Limiting),即使并行化,也可能很快达到上限。你可能需要在
-
处理超大数据量 (远超内存) :
- 如果一百万条记录(或解析后的数据)对于单个 Executor 的内存来说仍然太大,
flatMap
后直接.toDF()
可能导致 OOM (Out Of Memory)。这时可以考虑:- 在
fetchBatchData
里直接返回轻量级的数据结构(比如只包含必要字段的String
或Tuple
),减少内存占用。 - 不使用
.cache()
,或者使用StorageLevel.MEMORY_AND_DISK
。 - 处理完一批数据后直接写入中间存储(如 Parquet 文件),最后再合并。
- 在
- 如果一百万条记录(或解析后的数据)对于单个 Executor 的内存来说仍然太大,
安全建议
- API 密钥/令牌管理 : 如果 API 需要认证(比如 API Key 或 Bearer Token),绝对不要 硬编码在代码里。应该使用 Databricks 的 Secrets 功能来安全地存储和访问这些凭证。
// 读取 Databricks Secret // val apiKey = dbutils.secrets.get(scope = "your-secret-scope", key = "api-key") // 然后在请求头或参数中带上 apiKey // val response = requests.get(baseUrl, params = queryParams, headers = Map("Authorization" -> s"Bearer $apiKey"))
通过这种 Spark 并行化改造,相比于单机 Python 脚本,你应该能看到数据获取速度的巨大提升,将原本几十分钟甚至数小时的任务缩短到几分钟以内,具体取决于你的 Databricks 集群规模和网络条件以及 API 服务器的响应能力。