可学习的权重代码
本文最后更新于 2024-10-08,文章内容可能已经过时。
class LearnableCoefficient(nn.Module):
def __init__(self):
super(LearnableCoefficient, self).__init__()
self.bias = nn.Parameter(torch.FloatTensor([1.0]), requires_grad=True)
def forward(self, x):
out = x * self.bias
return out
class LearnableWeights(nn.Module):
def __init__(self):
super(LearnableWeights, self).__init__()
self.w1 = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
self.w2 = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
def forward(self, x1, x2):
out = x1 * self.w1 + x2 * self.w2
return out
本文是原创文章,采用 CC BY-NC-ND 4.0 协议,完整转载请注明来自 Titos
评论
匿名评论
隐私政策
你无需删除空行,直接评论以获取最佳展示效果