模型规模增长,巨大模型难以训练
Model scale trend:
- 1 billion(Google, 2016)
- 12 trillion(FB, latest)
Recommendation model charactorization:
- 99% embedding:memory intensive
- 1% NN:computation intensive
提出一种有效的分布式训练系统,该系统从优化算法和分布式系统架构两方面做了精心设计
- Algorithm:提出一种Hybrid training算法,其中dense和embedding分别使用不同的同步机制
- System:设计并实现了系统Persia,能够支持上述的Hybrid training算法
主要是列benchmark结果,对比收敛加速比。为了highlight Persia在系统和算法上都有优势,作者分别展示了两组实验:
实验配置
CTR BenchMark Tasks
3个公开benchmark task和1个快手内部生产task。
- TaoBao-Ad
- Avazu-Ad
- Criteo-Ad
- Kwai-Video
cluster
8*8 NVIDIA V100 with 100Gbps network
models
结果展示
- 达到指定AUC时,经过的总时间对比
这项实验验证了Persia在系统上面的能力,Persia均先到达收敛线,并且速度可达7.12X
- 达到指定AUC是,经过总steps数对比
Persia新提出的Hybrid Algorithm收敛很稳定。
- Scalability
随着GPU数量增加,Persia在四个任务上的扩展性很好。最后一个任务虽然异步比同步扩展性好,但是上面的实验已经其AUC偏低(上面的图中最后一个实验可以看出)
贡献点
- 可以支持参数规模为100T大小的推荐模型
- 提出了Hybrid algorithm,能够既保证收敛,又能充分压榨计算资源,提升利用率
Benchmark
- 模型size,Racing to 100T,并且这是刚需
- SOTA能支持最大的模型也就12T,Persia是第一个到100T规模的system:
- 全同步和全异步都有问题:
- 全同步,同步开销太大,计算速度慢
- 全异步,计算爽了,收敛亏了
- 于是Persia横空出世,设计出一种Hybrid Algorithm,并且系统开源了
Core:提出了Hybrid algorithm+异构分布式系统
Hybrid algorithm:
- embedding:异步训练
- dense部分:同步训练
异构分布式系统:
允许同步和异步同时存在优化模型,同时对GPU,CPU,存储等资源管理业做了些工作。
最后,已开源可复现。
XDL:Embedding上PS,其他在worker GPU上(本来不就该是这样??)
Baidu PS:利用多级缓存,cache常访问的item
这是支持Hybrid algorithm的基本理论
上图中,理解时应注意,同一批sample的workflow顺序不可能打破,即GE->FC->BC->Dsync->PEG。因此Async和Hybrid中,GE和PEG能够并行是因为来自于不同的batch。
- Fully Async:不同batch间,连Dsync和Dense部分的计算都是异步的。上一批batch计算的grads更新时,此批batch的仍可能在计算,存在staleness;
- Naive Hybrid:同batch的Dsync和计算保证同步,但与不同batch的GE和PEG仍然异步;
- Persia:在上述Naive Hybrid之上,加了Dsync和BC的overlap。
- Embedding Worker和NN worker上的Buffer cache
每个数据都有ID feature和非ID feature,一个要走Embedding worker,一个要走NN worker,最终都在NN worker汇合,汇合时需要能够再次将同一sample的数据拼凑在一起,靠的就是data loader上生成的ID以及两类worker上的buffer机制。
作用是为每个Sample找对应的worker:
- NN worker上的Buffer:非ID feature发送到NN worker上,最终要ID feature对应的Embedding worker的结果做对应。NN worker上就靠这个unique id做pull
- Embedding worker上的Buffer:记录在parameter从哪里来的,最后得把gradients push回去。
- 显存管理和cache(不知道有啥用,可能是为了应对动态embedding而做的LRU)
- 通信优化:
- NN部分的AllReduce通信,重在计算通信overlap,用的是Bagua;
- RPC通信优化:因传输的是连续空间的大Tensor,不想引入序列化开销。作为代替,使用了zero-copy serialization和deserialization的方法,直接对memory layout做序列化。
- PS上的访问balance问题:先均匀shuffle,再平均放在多个ps上,能够缓解这个问题。
- 通信压缩:
- 无损:其实就是多级索引。放弃对每个sample使用大的int64的ID,而对整个batch使用一个大的ID,每个小example直接使用index。
- 有损:FP16
- 容错机制
证明思路:只有embedding是stale的,最后推导出convergence rate不大于vainila SGD+staleness影响。而staleness只有embedding引入,其值为远小于1,因此基本可以认为趋近于0,收敛性与vanila SGD相当。