Understanding Positional Encoding

The positional encoding introduced in the Transformer is hard to pin down from the code alone; working the computation through on a small, lightweight input makes the concrete operations much clearer.

1. timestep_embedding in Unet

Source code

import math
import torch

def get_timestep_embedding(timesteps, embedding_dim):
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the last column when the dimension is odd
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
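
Written out, what this code computes is the sinusoidal embedding. With $h = \texttt{embedding\_dim} \,//\, 2$, timestep $t$, and frequency index $i \in \{0, \dots, h-1\}$,

$$
\mathrm{emb}[t,\ i] = \sin\!\Big(\frac{t}{10000^{\,i/(h-1)}}\Big),
\qquad
\mathrm{emb}[t,\ h+i] = \cos\!\Big(\frac{t}{10000^{\,i/(h-1)}}\Big),
$$

since $\exp\!\big(-i \cdot \ln(10000)/(h-1)\big) = 10000^{-i/(h-1)}$. When embedding_dim is odd, one zero column is appended on the right.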

Worked example

# def get_timestep_embedding(timesteps, embedding_dim):
## Assumed inputs
timesteps = torch.tensor([1, 2, 3, 4, 5])  # input timesteps
embedding_dim = 6                          # embedding dimension
## Assumed inputs

assert len(timesteps.shape) == 1  # check that timesteps is 1-D

half_dim = embedding_dim // 2  # half_dim = 6 // 2 = 3

emb = math.log(10000) / (half_dim - 1)  # emb = log(10000) / 2 ≈ 4.60517
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
# torch.arange(half_dim) gives tensor([0, 1, 2])
# emb = torch.exp(torch.tensor([0, 1, 2]) * -4.60517)
# emb = tensor([1.0000, 0.0100, 0.0001])

emb = emb.to(device=timesteps.device)  # timesteps is on the CPU here, so emb stays on the CPU

emb = timesteps.float()[:, None] * emb[None, :]
# timesteps.float()[:, None] turns timesteps into a column vector tensor([[1.], [2.], [3.], [4.], [5.]])
# emb[None, :] turns emb into a row vector tensor([[1.0000, 0.0100, 0.0001]])
# emb = tensor([[1.0000, 0.0100, 0.0001],
#               [2.0000, 0.0200, 0.0002],
#               [3.0000, 0.0300, 0.0003],
#               [4.0000, 0.0400, 0.0004],
#               [5.0000, 0.0500, 0.0005]])

emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# torch.sin(emb) = tensor([[ 0.8415, 0.0100, 0.0001],
#                          [ 0.9093, 0.0200, 0.0002],
#                          [ 0.1411, 0.0300, 0.0003],
#                          [-0.7568, 0.0400, 0.0004],
#                          [-0.9589, 0.0500, 0.0005]])
# torch.cos(emb) = tensor([[ 0.5403, 0.9999, 1.0000],
#                          [-0.4161, 0.9998, 1.0000],
#                          [-0.9900, 0.9996, 1.0000],
#                          [-0.6536, 0.9992, 1.0000],
#                          [ 0.2837, 0.9988, 1.0000]])
# emb = tensor([[ 0.8415, 0.0100, 0.0001,  0.5403, 0.9999, 1.0000],
#               [ 0.9093, 0.0200, 0.0002, -0.4161, 0.9998, 1.0000],
#               [ 0.1411, 0.0300, 0.0003, -0.9900, 0.9996, 1.0000],
#               [-0.7568, 0.0400, 0.0004, -0.6536, 0.9992, 1.0000],
#               [-0.9589, 0.0500, 0.0005,  0.2837, 0.9988, 1.0000]])

if embedding_dim % 2 == 1:  # zero pad
    emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
# embedding_dim % 2 == 0 here, so no padding is applied

# Final result:
# emb = tensor([[ 0.8415, 0.0100, 0.0001,  0.5403, 0.9999, 1.0000],
#               [ 0.9093, 0.0200, 0.0002, -0.4161, 0.9998, 1.0000],
#               [ 0.1411, 0.0300, 0.0003, -0.9900, 0.9996, 1.0000],
#               [-0.7568, 0.0400, 0.0004, -0.6536, 0.9992, 1.0000],
#               [-0.9589, 0.0500, 0.0005,  0.2837, 0.9988, 1.0000]])
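
As a quick sanity check, here is a minimal sketch (assuming the imports and the get_timestep_embedding definition from the source-code block above) that reproduces the walkthrough and also exercises the zero-pad branch with an odd embedding_dim:

# Same inputs as the walkthrough above.
emb = get_timestep_embedding(torch.tensor([1, 2, 3, 4, 5]), embedding_dim=6)
print(emb.shape)  # torch.Size([5, 6])
print(emb[0])     # ≈ [0.8415, 0.0100, 0.0001, 0.5403, 0.9999, 1.0000], the first row above

# With an odd embedding_dim, pad(emb, (0, 1, 0, 0)) appends one zero column on the right.
emb_odd = get_timestep_embedding(torch.tensor([1, 2, 3]), embedding_dim=7)
print(emb_odd.shape)   # torch.Size([3, 7])
print(emb_odd[:, -1])  # tensor([0., 0., 0.])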