import math

import torch


def get_timestep_embedding(timesteps, embedding_dim, max_period=10000):
    """Build sinusoidal embeddings for a 1-D batch of timesteps.

    ``embedding_dim // 2`` frequencies are laid out geometrically from 1 down
    to ``1 / max_period``. Each row is ``[sin(t * f) ..., cos(t * f) ...]``,
    with one trailing zero column appended when ``embedding_dim`` is odd.

    Args:
        timesteps: 1-D tensor of N timestep values; cast to float32.
        embedding_dim: Width of each output row.
        max_period: Sets the lowest frequency (``1 / max_period``). The
            default of 10000 matches the original hard-coded value.

    Returns:
        Float32 tensor of shape ``(N, embedding_dim)`` on ``timesteps.device``.

    Raises:
        ValueError: If ``timesteps`` is not 1-D.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if timesteps.dim() != 1:
        raise ValueError(
            f"timesteps must be 1-D, got shape {tuple(timesteps.shape)}"
        )
    half_dim = embedding_dim // 2
    # max(..., 1) avoids ZeroDivisionError when half_dim == 1 (embedding_dim
    # 2 or 3). In that case arange(1) == [0], so the only frequency is
    # exp(0) == 1 regardless of the scale.
    exponent = -math.log(max_period) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        * exponent
    )
    # Outer product: (N, 1) * (1, half_dim) -> (N, half_dim).
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:
        # Pad one zero column on the right of the last dim to reach an odd width.
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


# Worked example (module-level names kept from the original script).
timesteps = torch.tensor([1, 2, 3, 4, 5])
embedding_dim = 6
emb = get_timestep_embedding(timesteps, embedding_dim)
# Expected value of ``emb`` (as printed by torch, 4 decimals):
# tensor([[ 0.8415, 0.0100, 0.0001,  0.5403, 0.9999, 1.0000],
#         [ 0.9093, 0.0200, 0.0002, -0.4161, 0.9998, 1.0000],
#         [ 0.1411, 0.0300, 0.0003, -0.9900, 0.9996, 1.0000],
#         [-0.7568, 0.0400, 0.0004, -0.6536, 0.9992, 1.0000],
#         [-0.9589, 0.0500, 0.0005,  0.2837, 0.9988, 1.0000]])

if __name__ == "__main__":
    print(emb)