OpenSora Inference Sumamry
Model Architecture
OpenSora 的整体结构如下,Embedding Layer 包括 5 个
- PatchEmbedding3D: 对输入隐藏层的噪声进行 3D 卷积
Conv3d(4, 1152, kernel_size=(1, 2, 2), stride=(1, 2, 2))
- TimestepEmbedder: 对当前的时间步进行 Sinusoidal Embedding 后经过两个线性层
- SizeEmbedder: scalar 版本的 TimestepEmbedder
- t_block:
Linear(in_features=1152, out_features=6912, bias=True)
生成 MHSA 和 MLP 所需要的 shift, scale & gate 参数. - CaptionEmbedder: 对编码后的 prompt 经过两个线性层。
主干由 28 个 STDiT3 Block 组成。每个 STDiT3 block 由一个 spatial block + temporal block 组成。spatial/temporal block 由 MHSA + MHCA + FFN 组成。不同的是 spatial block 中的 MHSA 的序列大小为空间维度 (HW),时间维度 T 被看作 batch_size 处理,temporal block 则是反过来。
T2IFinal Layer 由一个 LayerNorm + Linear 组成,将维度缩放回 patch_size*vae_outchannels
再 reshape 成 (B, C, T, H, W)
想要生成 t
秒分辨率为 (H,W)
的视频,采用 class-free guidance,则 B=2
. 经过 VAE 后的 shape 为 (B, 4, t*25.5*5//17, H/8, W/8)
,之后经过卷积时如果不能被 patchsize 整除需要先 padding 后再进行,最后对应 Transform block 的输入 shape 为 (B, C, 7.5t, H/16, W/16)
. 以下是 DeepLink 测试配置下的具体值
shape | S | T | numel=ST |
---|---|---|---|
204_640_360 | 920 | 60 | 55200 |
408_640_360 | 920 | 120 | 110400 |
51_1280_720 | 3600 | 15 | 54000 |
102_1280_720 | 3600 | 30 | 108000 |
Overall Computation
一个标准的 Transformer block 的 GEMM 计算量包括 QKVLinear(6bshh), Q@K.T(2bssh), Score@V(2bssh), OLinear(2bshh) 以及 FFN 的两个 Linear(16bshh),共计 (24bshh + 4bssh). OpenSora 的一个Transformer block 由 spatial+temporal_attention + 2cross_attention + 2FFN 组成,一共有 28 个这样的 block. Spatial Attention 中 b = BT, s = S, Temporal Attention 中 b = BS, S = T, Cross Attention 中 b = B, S_q = TS, S_k = S_v = TOKEN (最大 prompt 长度, 300) 代入公式可得 OpenSora
- Spatial Attention Comp =
8BTShh+4BTSSh
- Temporal Attention Comp =
8BSThh+4BSTTh
- Cross Attention Comp =
2B(TS*2+TOKEN*2)hh+4BST*TOKEN*h
,TS*2
代表 QO Linear,TOKEN*2
代表 KV Linear - Feed Forward Network Comp = 16BSThh
则一个 OpenSora Transformer Block 总计算量为56BTShh+4BTSh(S+T+2*TOKEN)+8B*TOKEN*hh
FLOPs. 令 N = TS 则可化简为(56BN+8B*TOKEN)*hh+4BNh(S+T+2*TOKEN)
. 对于生成 shape 为 204_640_360 的视频每个 block TFLOPs =(56*2*55200+8*2*300)*1152^2+4*2*55200*1152*(920+60+300) = 8.86T
, 整个 Backbone GEMMs TFLOP=8.86*28=256.94T
用 torch-xla 进行一遍推理后 trace mhlo (后文叙述) 得到的结果如下,考虑到前面的 Embedding 层和后面的 T2IFinalLayer,手算结果基本准确。
shape | GEMM TFLOPS | Vector TFLOPS |
---|---|---|
204_640_360 | 260.978 | 0.795 |
408_640_360 | 523.479 | 1.608 |
51_1280_720 | 292.026 | 1.160 |
102_1280_720 | 584.284 | 2.324 |
Overall Communication
Megatron
Megatron 张量并行中每个 transformer block 的通信量为 Attention 计算和 FFN 的各一次 AllReduce, 一次 AllRedce 通信量为 2BTSh(N-1)/N*2
Bytes,一个 block 有 3*2
次,总共便是 24BTSh(N-1)/N Bytes
.
Two Dimension
2D 张量并行将输入分别将 B 和 M 沿着 X 和 Y 维度切分 b_xNm_y,将第一个矩阵的行和列分别沿着 X 和 Y 维度切分 m_xh_y,进行乘法前,输入会沿着 Y 轴进行 All-Gather b_xNM,通信量为 BTSM/N*(N_y-1)*2
Bytes,权重沿着 X 轴进行 All-Gather Mh_y,通信量为 MH/N*(N_x-1)*2
Bytes,这样输出的 B 和 H 沿着 X 和 Y 维度切分 b_xNh_y. 第二个矩阵的行和列分别沿着 Y 和 X 维度切分 h_ym_x,进行乘法前,权重沿着 X 轴进行 All-Gather h_yM,通信量为 MH/N*(N_x-1)*2
Bytes. 这样相乘后 Y 轴的每个设备上都存有部分和结果 b_xNM,再沿着 Y 轴进行 Reduce-Scatter b_xNm_y,通信量为 BTSM/N*(N_y-1)*2
Bytes.
一次这样切分计算通信量总计为 (BTSM/N*(N_y-1)*2 + MH/N*(N_x-1)*2)*2
Bytes. 一个 Transformer block 有 3*2
次,其中Attention 和 CrossAttention的 QKVOLinear(M=1152, H=1152), FFNUp&Down(M=1152, H=4H),考虑 cfg B 只能为 2,因此 N_x=2,N_y=8
.
1 |
|
DS-Ulysses 每个 Transformer block 的通信量 QKVO 各一次 All2All:
Spatial/Temporal Attention:
Cross Attention:
加起来便是
Ring-Attention 通过将 QKV 沿着序列维度切分到每个设备上,计算的同时异步通信 KV. 因此 Spatial/Temporal Attention 通信量为 2*BTSh*(N-1)/N*2
Bytes, Cross Attention 的通信量为 2*B*TOKEN*h*(N-1)/N*2
. 一个 Transformer block 各有两次,因此一个 block 总的通信量为 2*B(TS+TOKEN)h*(N-1)/N*2*2
Bytes. spatial 和 temporal block 之间需要一次 All2All 将切分的序列维度在空间和时间维度上面转换。
理论通信量如下
shape/method | Megatron(GB) | 2Dimension(GB) | DS-Ulysses(GB) | Ring-Attention(GB) |
---|---|---|---|---|
204_640_360 | 80.124 | 34.892 | 16.0538 | 25.001 |
408_640_360 | 160.248 | 69.716 | 32.078 | 50.002 |
51_1280_720 | 78.382 | 34.135 | 15.705 | 24.468 |
102_1280_720 | 156.764 | 68.202 | 31.382 | 48.936 |
Computation Disassembly
Megatron
Operation | IShape | WShape | OShape | Comp | 204.640_360(GFLOPs) | Utilization(%) | Latency(ms) |
---|---|---|---|---|---|---|---|
RMSNorm | (B,S,T,h) | (h, ) | (B,S,T,h) | 4BSh | 0.473 | 19.260 | 2.466 |
t2i_modulate | (B,S,T,h) | (2, h) | (B,S,T,h) | 2BSh | 0.237 | 9.635 | 2.466 |
QKV_Linear | (B,S,T,h) | (h,3h/N) | (B,S,T,3h/N) | 6BShh/N | 51.614 | 85.148 | 0.4694 |
RMSNorm(Q) | (BT,NA/N,S,HA) | (1,HA) | (BT,NA/N,S,HA) | 4BSh/N | 0.0299 | 18.480 | 0.163 |
RMSNorm(K) | (BT,NA/N,S,HA) | (1,HA) | (BT,NA/N,S,HA) | 4BSh/N | 0.0290 | 18.480 | 0.163 |
RoPE(Q)+scale | (BT,NA/N,S,HA) | (1,S) | (BT,NA/N,S,HA) | (3+1)BSh/N | 0.0299 | 90.602 | 0.278 |
RoPE(K) | (BT,NA/N,S,HA) | (1,HA) | (BT,NA/N,S,HA) | 4BSh/N | 0.0199 | ||
Q@K.T | (BT,NA/N,S,HA) | (BT, HA/N, HA, S) | (BT,NA/N,S,S) | 2BThSS/N | 13.621 | ||
Softmax+DropOut | (BT,NA/N,S,S) | None | (BT,NA/N,S,S) | 6BT(NA)SS/N | 0.037 | ||
Score@V | (BT,NA/N,S,S) | (BT,NA/N,S,HA) | (BT,NA/N,S,HA) | 2BThSS/N | 13.621 | ||
O_Linear | (B, T, S, h/N) | (h/N, h) | (B,T,S,h) | 2BThh/N | 17.205 | 21.8477 | 0.610 |
Gate_ResAdd | (B,T,S,h) | (h,) | (B,T,S,h) | 2BSh | 0.237 | 9.635 | 2.4668 |
Q_Linear | (B,S,T,h) | (h,h/N) | (B,T,S,h/N) | 2BThh/N | 17.056 | 85.148 | 0.156 |
scale | (B,NA/N,S,T,HA) | None | (B,NA/N,S,T,HA) | BTh/N | 0.007 | ||
KV_Linear | (B,TOKEN,h) | (h,2h/N) | (B,TOKEN,2h/N) | 4B(TOKEN)h/N | 0.188 | 8.993 | 0.016 |
Q@K.T | (B,NA/N,TS,HA) | (B,NA/N,HA,TOKEN) | (B,NA/N,TS,TOKEN) | 2BTh(TOKEN)/N | 8.883 | 52.696 | 0.286 |
Softmax+DropOut | (B,NA/N,TS,TOKEN) | None | (B,NA/N,TS,TOKEN) | 6BT(NA)S(TOKEN)/N | 0.0116 | ||
Score@V | (B,NA/N,TS,TOKEN) | (B, NA/N, TOKEN, HA) | (B,NA/N,S,T,HA) | 2BTh(TOKEN)/N | 8.883 | ||
O_Linear | (B,T,S,h/N) | (h/N,h) | (B,T,S,h) | 2BThh/N | 17.056 | 21.847 | 0.610 |
Gate_ResAdd | (B,T,S,h) | (h,) | (B,T,S,h) | 2BSh | 0.237 | 9.635 | 2.470 |
RMSNorm | (B,S,T,h) | (h,) | (B,S,T,h) | 4BSh | 0.474 | 19.260 | 2.466 |
t2i_modulate | (B,S,T,h) | (2,h) | (B,S,T,h) | 2BSh | 0.237 | 9.635 | 2.466 |
FFN_Linear1 | (B,TS,h) | (h,4h/N) | (B,T,S,4h/N) | 8BThh/N | 68.225 | 98.641 | 0.740 |
GeLU | (B,TS,4h/N) | None | (B,TS,4h/N) | 2BTh/N | |||
FFN_Linear2 | (B,TS,4h/N) | (4h/N,h) | (B,T,S,h) | 8BThh/N | 68.225 | 81.220 | 0.656 |
Two Dimension
Operation | IShape | WShape | OShape | Comp | 204_640_360(GFLOPs) | Utilization(%) | Latency(ms) |
---|---|---|---|---|---|---|---|
RMSNorm | (B/Nx,ST,h) | (h,) | (B/Nx,ST,h) | 4BSt/hNx | 0.237 | 19.148 | 1.237 |
t2i_modulate | (B/Nx,ST,h) | (2,h) | (B/Nx,ST,h) | 2BSt/hNx | 0.118 | 9.581 | 1.236 |
QKV_Linear | (B/Nx,ST,h) | (h,3h/Ny) | (B/Nx,ST,3h/Ny) | 6BSt/h/N | 51.614 | 90.435 | 0.442 |
RMSNorm(Q) | (BT/Nx,NA/Ny,S,HA) | (1,HA) | (BT/Nx,NA/Ny,S,HA) | 4BSt/hN | 0.0299 | 18.480 | 0.163 |
RMSNorm(K) | (BT/Nx,NA/Ny,S,HA) | (1,HA) | (BT/Nx,NA/Ny,S,HA) | 4BSt/hN | 0.0299 | 18.480 | 0.163 |
RoPE(Q)+scale | (BT/Nx,NA/Ny,S,HA) | (1,S) | (BT/Nx,NA/Ny,S,HA) | (3+1)BSt/hN | 0.0299 | 94.684 | 0.251 |
RoPE(K) | (BT/Nx,NA/Ny,S,HA) | (1,HA) | (BT/Nx,NA/Ny,S,HA) | 4BSt/hN | 0.0199 | ||
Q@K.T | (BT/Nx,NA/Ny,S,HA) | (BT/Nx,NA/Ny,HA,S) | (BT/Nx,NA/Ny,S,HA) | 2BThS/N | 13.621 | ||
Softmax+DropOut | (BT/Nx,NA/Ny,S,S) | None | (BT/Nx,NA/Ny,S,S) | 6BT(NA)S/N | 0.0370 | ||
Score@V | (BT/Nx,NA/Ny,S,S) | (BT/Nx,NA/Ny,S,HA) | (BT/Nx,NA/Ny,S,HA) | 2BThS/N | 13.621 | ||
O_Linear | (B/Nx,TS,h/Ny) | (h/Ny,h) | (B/Nx,TS,h) | 2BThS/hN | 17.056 | 41.322 | 0.322 |
Gate_ResAdd | (B/Nx,TS,h) | (h,) | (B/Nx,TS,h) | 2BSt/hNx | 0.118 | 9.582 | 1.236 |
Q_Linear | (B/Nx,ST,h) | (h,/Ny) | (B/Nx,TS,h/Ny) | 2BTS/h/N | 17.056 | 90.435 | 0.281 |
scale | (B/Nx,NA/Ny,ST,HA) | None | (B/Nx,NA/Ny,ST,HA) | BTS/hN | 0.007 | ||
KV_Linear | (B/Nx,TOKEN,h) | (h,2h/Ny) | (B,TOKEN,2h/Ny) | 4B(TOKEN)h/N | 0.188 | 0.0184 | 15.758 |
Q@K.T | (B/Nx,NA/Ny,TS,HA) | (B/Nx,NA/Ny,HA,TOK) | (B/Nx,NA/Ny,TS,TOK) | 2BThS(TOKEN)/N | 8.883 | 53.470 | 0.273 |
Softmax+DropOut | (B/Nx,NA/Ny,TS,TOK) | None | (B/Nx,NA/Ny,TS,TOK) | 6BT(NA)S(TOKEN)/N | 0.012 | ||
Score@V | (B/Nx,NA/Ny,TS,TOK) | (BT/Nx,NA/Ny,TOK) | (B/Nx,NA/Ny,TS,HA) | 2BThS(TOKEN)/N | 8.883 | 53.470 | 0.273 |
O_Linear | (B/Nx,TS,h/Ny) | (h/Ny,h) | (B/Nx,TS,h) | 2BTS/h/N | 17.056 | 98.745 | 2.159 |
Gate_ResAdd | (B/Nx,TS,h) | (h,) | (B/Nx,TS,h) | 2BTS/hNx | 0.118 | 9.635 | 2.470 |
RMSNorm | (B/Nx,ST,h) | (h,) | (B/ST,h) | 4BSt/hNx | 0.237 | 19.260 | 2.461 |
t2i_modulate | (B/Nx,ST,h) | (2,h) | (B/Nx,ST) | 2BSt/hNx | 0.118 | 9.635 | 2.470 |
FFN_Linear1 | (B/Nx,TS,h) | (h,4h/Ny) | (B/Nx,TS,4h/Ny) | 8BTS/hN | 68.225 | 98.687 | 0.740 |
GeLU | (B/Nx,TS,4h/Ny) | None | (B/Nx,TS,4h/Ny) | 2BTS/4hNx | 68.225 | ||
FFN_Linear2 | (B/Nx,TS,4h/Ny) | (4h/Ny,h) | (B/Nx,TS,h) | 8BTS/hN | 68.225 | 97.161 | 0.533 |
DeepSpeed-Ulysses
Operation | IShape | WShape | OShape | Comp | 204_640_360(GFLOPs) | Utilization(%) | Latency(ms) |
---|---|---|---|---|---|---|---|
RMSNorm | (B,ST/N,h) | (h,) | (B,ST/N,h) | 4BSt/hN | 0.237 | 18.480 | 0.163 |
t2i_modulate | (B,ST/N,h) | (2,h) | (B,ST/N,h) | 2BSt/hN | 0.119 | 9.2592 | 0.163 |
QKV_Linear | (BT,S/N,h) | (h,3h) | (BT,S/N,3h) | 6BTS/hN | 51.614 | 83.026 | 0.486 |
RMSNorm(Q) | (BT,NA,S/N,HA) | (1,HA) | (BT,NA,S/N,HA) | 4BSt/hN | 0.0299 | 18.480 | 0.163 |
RMSNorm(K) | (BT,NA,S/N,HA) | (1,HA) | (BT,NA,S/N,HA) | 4BSt/hN | 0.0299 | 18.480 | 0.163 |
RoPE(Q)+scale | (BT,NA,S/N,HA) | (1,S/N) | (BT,NA,S/N,HA) | (3+1)BSt/hN | 0.0299 | 90.602 | 0.278 |
RoPE(K) | (BT,NA,S/N,HA) | (1,S/N) | (B,NA,ST/N,HA) | 3BSt/hN | 0.0199 | ||
Q@K.T | (BT,NA/N,S,HA) | (BT,NA/N,HA,S) | (BT,NA/N,S) | 2BThS/N | 13.621 | ||
Softmax+DropOut | (BT,NA/N,S) | None | (BT,NA/N,S) | 6BT(NA)S/N | 0.0370 | ||
Score@V | (BT,NA/N,S) | (BT,NA/N,S,HA) | (BT,NA/N,S,HA) | 2BThS/N | 13.621 | ||
O_Linear | (B,TS/N,h) | (h,) | (B,TS/N,h) | 2BTS/hN | 17.2045884375 | 83.025511 | 0.161891 |
Gate_ResAdd | (B,TS/N,h) | (h,) | (B,TS/N,h) | 2BTS/hN | 0.119 | 9.259260 | 0.163147 |
Q_Linear | (B,TS/N,h) | (h,) | (B,TS/N,h) | 2BTS/hN | 17.056 | 83.096255 | 0.16035 |
scale | (B,NA,ST/N,HA) | None | (B,NA,ST/N,HA) | BTS/hN | 0.007 | ||
KV_Linear | (B,TOKEN/N,h) | (h,2h) | (B,TOKEN/N,2h) | 4B(TOKEN)h/N | 0.188 | 6.78544 | 0.02163 |
Q@K.T | (B,NA/N,TS,HA) | (B,NA/N,HA,TOKEN) | (B,NA/N,TS,TOKEN) | 2BThS(TOKEN)/N | 8.883 | 52.696 | 0.286 |
Softmax+DropOut | (B,NA/N,TS,TOKEN) | None | (B,NA/N,TS,TOKEN) | 6BT(NA)S(TOKEN)/N | 0.012 | ||
Score@V | (B,NA/N,TS,TOKEN) | (BT,NA/N,TOKEN,HA) | (B,NA/N,TS,HA) | 2BThS(TOKEN)/N | 8.883 | 52.670 | 0.286 |
O_Linear | (B,TS/N,h) | (h,) | (B,TS/N,h) | 2BTS/hN | 17.056 | 83.096 | 0.160 |
Gate_ResAdd | (B,TS/N,h) | (h,) | (B,TS/N,h) | 2BTS/hN | 0.118 | 9.260 | 0.163 |
RMSNorm | (B,TS/N,h) | (h,) | (B,TS/N,h) | 4BSt/hN | 0.237 | 18.480 | 0.163 |
t2i_modulate | (B,TS/N,h) | (2,h) | (B,TS/N,h) | 2BSt/hN | 0.119 | 9.260 | 0.163 |
FFN_Linear1 | (B,TS/N,h) | (h,4h) | (B,TS/N,4h) | 8BTS/hN | 68.225 | 98.641 | 0.750 |
GeLU | (B,TS/N,4h) | None | (B,TS/N,4h) | 2BTS/4hN | |||
FFN_Linear2 | (B,TS/N,4h) | (4h,) | (B,TS/N,h) | 8BTS/hN | 68.225 | 88.560 | 0.602 |
分层的各个计算用时占比和利用率的示意图如下所示,可以看出 Vector 用时占主要部分,由于 Megatron 无法切分导致用时最长。
Ring Attention
MLIR Visitor
IRVisitor 类是一个基于 JAX 的 MLIR 结构遍历器,设计用于递归地遍历 MLIR IR 的各种节点(如操作、区域、块等)。它是一个基类,提供了通用的访问逻辑,允许子类通过重写特定方法来处理不同的节点类型。以下是对其工作原理的详细解释,重点说明它如何遍历整个 MHLO. 核心方法 visit 是遍历的入口方法,接收一个节点(node)作为参数。根据 node 的类型,动态构造访问器方法名:
- 如果 node 是 ir.Operation(一个具体的 MLIR 操作),方法名设为 visit_operation。
- 如果 node 是 ir.OpView(操作的视图,通常是特定操作的封装),解析其名称,尝试匹配特定操作类型(如 visit_add),如果没有匹配则回退到 visit_opview。
- 如果 node 是 ir.Region(表示操作中的一个代码区域),方法名设为 visit_region。
- 如果 node 是 ir.Block(区域中的一个基本块),方法名设为 visit_block。
- 使用 getattr 获取对应的方法,如果没有定义特定方法,则调用 generic_visit 作为默认行为.
generic_visit(self, node)
是当没有特定访问器方法时调用的通用方法。
TFlops Visitor
TFLopsIRVisitor 重写了基类 IRVisitor 中的部分方法,针对特定 MHLO 操作计算 TFLOPS:
visit_dot
: 处理 dot 操作。
获取输入张量的形状:- lhs_shape:左操作数(node.lhs)的形状,使用 ir.RankedTensorType 解析。
- result_shape:输出张量(node.result)的形状。
- 确定收缩维度(contract_dim):假设 dot 操作的收缩维度是 lhs_shape 的最后一个维度(-1),即两个张量相乘时对齐的维度。
计算 TFLOPS:2 * contract_dim * np.prod(np.array(result_shape))
. 将元组("dot", dot_tflops, node.location, node)
添加到 self.op_tflops. 递归调用 self.generic_visit(node) 继续遍历该节点的子结构(如果有)。
visit_dot_general
:处理 dot_general 操作。
获取输入和输出张量的形状:- lhs_shape:左操作数(node.lhs)的形状。
- result_shape:输出张量(node.result)的形状。
- 获取收缩维度信息:从
node.attributes
中提取dot_dimension_numbers
属性,解析为mhlo.DotDimensionNumbers
对象。 - contract_dims:左操作数的收缩维度索引。
- contract_size:通过 np.prod(np.array(lhs_shape)[contract_dims]) 计算收缩维度的总元素数。
计算 TFLOPS:2 * contract_size * np.prod(np.array(result_shape))
. 递归调用 self.generic_visit(node) 继续遍历子结构。
为了确定一个 Transformer block 中的 dot 和 dot_general 操作分别对应于什么,我们需要一个 RMSCollector
,当遇到 rsqrt
操作时,visit_rsqrt
方法通过 _parse_loc_lineno
提取行号,并添加到 RMSCollector.rms_locs
用于临时存储 RMS Norm 的行号。根据 rms_locs 分割为 spt_qkv_ranges、spt_attn_ranges、ffn_ranges. 遍历 IR,遇到 dot 或 dot_general 时,根据行号匹配到特定块,更新计数器。当计数器满足条件,记录块并重置状态。输出 attention_blocks 和 ffn_blocks 包含所有匹配的块。
1 |
|
Vector 计算量则是通过统计 add, subtract, multiply, divide, rsqrt, negate & exponential,操作对应的元素个数即为计算量。
Comm Visitor
CommIRVisitor
继承 IRVisitor 基类,专门用于遍历 MHLO 中的通信操作,计算每个通信操作(communication op)的通信量(comm volume)和通信延迟(comm latency)。支持多种集体通信操作(如 all_reduce、all_gather、reduce_scatter 和 all_to_all)CommIRVisitor
为每种通信操作实现了特定的 visit 方法,计算通信量和延迟。_get_ring_comm_stastics
用于为环形通信模式计算集体通信操作的通信量和延迟。通过节点的 replica_groups
属性获取本次通信的设备数 num_device
,环形通信中每个设备与除自身外的其他设备通信的平均比例为 (num_devices - 1) / num_devices
. 前三种通信操作都可以使用 Ring 模式,需要注意由于 AllReduce 实际上是 ReduceScatter + AllGather,因此计算通信量的时候要乘以 2. 环状通信量计算公式为
- 计算频率(freqs)
函数首先计算一组频率,用于生成正弦和余弦波。频率的计算基于指数衰减:
定义half=dim/2,即嵌入维度的一半。生成一个从 0 到 half-1 的序列:torch.arange(start=0,end-=half)
频率公式为:
Result
与单个 A100 以及采用 DS-Ulysses 并行策略的单机 8 卡 A100 的推理结果做对比。由上述计算量拆解可以看出 Vector 操作占主导部分,由于 Megatron 无法切分序列维度导致了重复计算,而 TwoDimension 可以切分 batch_size 维度,DS_Ulysses 则是可以完全切分序列维度。RingAttention 由于本身 Attention 计算的序列长度不是很长,再在计算时进行进一步切分导致了 FlashAttention 的利用率极低,导致用时很长。
分辨率 | Backbone 理论计算量 | device | 设备数量 | 并行方法 | latency(s) STDT3 | alloc_mem(GB) | comm_volume(MB) | 利用率 |
---|---|---|---|---|---|---|---|---|
204_640_360 | 260.978/0.795 | A100 | 1 | xxx | 99.00/115.43 | 0.253 | ||
408_640_360 | 523.479/1.608 | A100 | 1 | xxx | 202.3/250.66 | 0.248 | ||
51_1280_720 | 292.026/1.160 | A100 | 1 | xxx | 104.24/104.13 | 0.269 | ||
102_1280_720 | 584.284/2.324 | A100 | 1 | xxx | 206.92/206.66 | 0.271 | ||
204_640_360 | 260.978/0.795 | A100 | 8 | DS_Ulysses | 13.76/28.82 | 0.124 | ||
408_640_360 | 523.479/1.608 | A100 | 8 | DS_Ulysses | 26.61/51.62 | 0.178 | ||
51_1280_720 | 292.026/1.160 | A100 | 8 | DS_Ulysses | 25.26/34.56 | 0.139 | ||
102_1280_720 | 584.284/2.324 | A100 | 8 | DS_Ulysses | 36.5/53.19 | 0.192 | ||
204_640_360 | 326,227/1.667 | TX8 | 16x(4x4) | Megatron | 73.38 | 7.021 | 143956.143 | 0.088 |
408_640_360 | 654,693/3.26 | TX8 | 16x(4x4) | Megatron | 145.71 | 13.807 | 267921.362 | 0.089 |
51_1280_720 | 375,531/2.218 | TX8 | 16x(4x4) | Megatron | 72.18 | 24.126 | 140835.66 | 0.085 |
102_1280_720 | 751,315/4.442 | TX8 | 16x(4x4) | Megatron | 142.95 | 48.016 | 281662.397 | 0.086 |
204_640_360 | 495,591/1.5 | TX8 | 16x(4x4) | Two_Dimension | 50.43 | 3.749 | 107759.849 | 0.129 |
408_640_360 | 992,661/3.021 | TX8 | 16x(4x4) | Two_Dimension | 100.59 | 7.26 | 215253.543 | 0.13 |
51_1280_720 | 530,634/1.945 | TX8 | 16x(4x4) | Two_Dimension | 49.77 | 12.19 | 105423.029 | 0.138 |
102_1280_720 | 1061,151/3.9 | TX8 | 16x(4x4) | Two_Dimension | 99.03 | 24.141 | 210579.804 | 0.139 |
204_640_360 | 584,310/0.431 | TX8 | 16x(4x4) | DS_Ulysses | 36.93 | 3.749 | 103923.135 | 0.1 |
408_640_360 | 1130,119/0.864 | TX8 | 16x(4x4) | DS_Ulysses | 73.65 | 7.26 | 103152.307 | 0.1 |
51_1280_720 | 588,764/0.446 | TX8 | 16x(4x4) | DS_Ulysses | 36.84 | 12.19 | 116269.846 | 0.11 |
102_1280_720 | 1177,769/0.892 | TX8 | 16x(4x4) | DS_Ulysses | 72.42 | 24.141 | 202365.35 | 0.11 |
204_640_360 | 584,310/0.431 | TX8 | 16x(4x4) | Ring_Attention | 43.70 | |||
408_640_360 | 1130,119/0.864 | TX8 | 16x(4x4) | Ring_Attention | 86.37 | |||
51_1280_720 | 588,764/0.446 | TX8 | 16x(4x4) | Ring_Attention | 43.69 | |||
102_1280_720 | 1177,769/0.892 | TX8 | 16x(4x4) | Ring_Attention | 84.38 |