
berryz2007/Transformer Demo

forked from Hauk Zero/Transformer Demo 
This repository does not declare an open-source license file (LICENSE); check the project description and its upstream dependencies before use.
pe.py 1.06 KB
Hauk Zero committed on 2024-07-26 17:22 · add all
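
pe.py implements the fixed sinusoidal positional encoding from "Attention Is All You Need" (Vaswani et al., 2017): PE(pos, 2i) = sin(pos / 10000^(2i / d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)). The table is precomputed once for max_len positions, registered as a buffer, and added to the input embeddings in forward.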
import math

import torch
from torch import nn


class PositionEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.5):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len).unsqueeze(1)
        # Computing the divisor with exp in log space is faster (and avoids
        # fractional powers): exp(2i * (-ln(10000) / d_model)) == 10000^(-2i / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(1e4) / d_model))
        # sin on even dimensions (2i), cos on odd dimensions (2i + 1)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcasts over the batch
        # Register as a buffer so it is stored in state_dict and moves with the
        # model between devices, without being a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Out-of-place add: in-place `x += ...` mutates the caller's tensor
        # and can break autograd when x requires grad
        x = x + self.pe[:, :x.shape[1]]
        return self.dropout(x)


if __name__ == '__main__':
    d_model = 4
    max_len = 9
    x = torch.randn(1, max_len, d_model)
    pe = PositionEncoding(d_model, max_len=max_len)
    print(f"x = {x}")
    print(f"pe(x) = {pe(x)}")
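
A quick sanity check (a sketch, not part of the repository) confirming that the log-space computation above matches the closed-form sinusoids from the paper:

import torch

d_model, max_len = 4, 9
enc = PositionEncoding(d_model, max_len=max_len)

pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
two_i = torch.arange(0, d_model, 2, dtype=torch.float32)       # even dims: 0, 2, ...
expected = torch.zeros(max_len, d_model)
expected[:, 0::2] = torch.sin(pos / (1e4 ** (two_i / d_model)))
expected[:, 1::2] = torch.cos(pos / (1e4 ** (two_i / d_model)))

# The registered buffer has shape (1, max_len, d_model)
assert torch.allclose(enc.pe.squeeze(0), expected, atol=1e-5)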