返回

PyTorch寻最优w值:sin(w*x)^2分类奇偶数

python

使用PyTorch寻找sin(w*x)^2中的w值用于奇偶分类

利用神经网络逼近目标函数是机器学习中的常见任务。此问题中,任务是将整数分类为偶数或奇数,通过拟合函数sin(w*x)^2,并寻找合适的参数w。一个合适的w可以使得该函数在输入为偶数时接近0,输入为奇数时接近1。此目标函数具备周期性,存在多个解,目标为找到接近π/2的值。

问题分析

尝试使用简单的神经网络结构,包括一个无偏置的线性层和一个激活函数(sin(x)^2),目的是找到最优的 w,让 sin(w*x)^2 函数的结果能够将输入数据成功区分奇偶数。实际的实践中,使用了均方误差损失函数(MSELoss),并且使用随机梯度下降法(SGD)进行优化。然而,模型似乎没有收敛到一个正确的解,这在参数图上的表现为权重值震荡,而不是稳定到一个特定值。

可能出现的原因是损失函数的梯度可能非常平缓,或者优化器选择了不适合这种高度非线性问题的策略。另一个因素可能是选择了不合适的学习率,以及没有利用梯度下降优化算法的一些策略,导致训练不稳定。

解决方案与代码示例

以下针对问题原因提供了几种可能的解决方案。

调整学习率

初始的学习率 0.0004 可能太小,导致训练过程缓慢,甚至难以脱离局部最小值。增加学习率或者使用动态调整的学习率可能加速训练进程,并且提高找到合适参数的概率。一种方法是使用 torch.optim.lr_scheduler 模块中的学习率调度器,例如 ReduceLROnPlateau,它会在验证集loss不再下降时降低学习率。

import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 1, bias=False)
        with torch.no_grad():
            self.fc1.weight.data.fill_(1.5)

    def forward(self, x):
        x = self.fc1(x)
        return torch.sin(x)**2

# 数据准备保持不变,省略
def generate_data(size):
    x = np.random.randint(0, size, size)
    return x.astype(float), (x % 2).astype(float)
train_x, train_y = generate_data(1000)
val_x, val_y = generate_data(1000)
train_x = torch.tensor(train_x, dtype=torch.float32).reshape(-1, 1)
train_y = torch.tensor(train_y, dtype=torch.float32).reshape(-1, 1)
val_x = torch.tensor(val_x, dtype=torch.float32).reshape(-1, 1)
val_y = torch.tensor(val_y, dtype=torch.float32).reshape(-1, 1)
net = Net()
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01) # Increase lr

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, verbose=True) # Adjust learning rate when val loss is not decreasing
weights = []

for epoch in range(200): #增加epoch次数
    net.train()
    optimizer.zero_grad()
    output = net(train_x)
    loss = criterion(output, train_y)
    loss.backward()
    optimizer.step()
    net.eval()
    with torch.no_grad():
        val_output = net(val_x)
        val_loss = criterion(val_output, val_y)
    scheduler.step(val_loss)

    if epoch % 10 == 0: # 减少打印频率
        print(f"Epoch {epoch}  Loss: {loss.item():.4f}  Val Loss: {val_loss.item():.4f}")
        w = net.fc1.weight.item()
        print(f"Weight: {w:.4f} (target: {np.pi/2:.4f})")
        print("---")

    weights.append(net.fc1.weight.item())

import matplotlib.pyplot as plt
plt.plot(weights)
plt.plot([np.pi/2]*len(weights), linestyle='--') # Use dashed lines for comparison target value.
plt.title("Weights During Training")
plt.show()

# Test the model as described in the origin code, but removed the final parameter's printing
test_numbers = np.arange(0, 1500, 1)
net.eval()
with torch.no_grad():
    for x in test_numbers:
        test_input = torch.tensor([[float(x)]], dtype=torch.float32)
        pred = net(test_input).item()
        print("✅" if (pred < 0.5) == (x % 2 == 0) else "❌", end="")
        if (x+1) % 60 == 0:
            print()

该代码调整了学习率为 0.01,增加了 epoch 次数,并且加入了动态学习率调整。
此举可以提高收敛速度和权重稳定性,防止权重震荡。此外使用 matplotlib 生成了权重变化曲线图。

使用不同的优化器

随机梯度下降(SGD)可能不适合这个任务,尝试 Adam 这种自适应的优化器。Adam 算法能自动调整每个参数的学习率,使得在参数空间的不同方向都有较好的收敛速度。它对学习率超参数的选择要求不高,比传统梯度下降更加稳健。

将优化器更改为 Adam:

# 模型初始化,保持不变
net = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.005) # 使用 Adam 并调高 learning rate

# 训练代码同调整学习率的方案
#...

使用 Adam 优化器时可以尝试不同的初始学习率,通常在0.0010.01 之间。

探索初始化方式

模型初始化也是重要的,需要尝试更加精确的初始化,使得初始权重更加接近 π/2。另外,可能初始化的时候就处于某个函数的局部最小值,导致模型无法逃出, 可以加入随机的微小扰动,打破对称。

对权值初始化进行以下操作:

# 在网络定义时添加扰动,将微小随机值添加到预设的值上
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 1, bias=False)
        with torch.no_grad():
            init_w =  1.5 + (torch.rand(1).item() - 0.5) * 0.005  # 添加小的随机扰动
            self.fc1.weight.data.fill_(init_w) #更加精确的初始化

    def forward(self, x):
        x = self.fc1(x)
        return torch.sin(x)**2

关于模型选择的考量

本问题中使用的简单模型具有一定的局限性。尽管可以通过找到特定的w来分离偶数和奇数,但这可能不是最鲁棒或者最优的方法。尝试其他结构模型(例如一个两层的神经网络,带有多个线性层和非线性激活函数),或许可以更加快速的训练模型达到目标精度,在找到特定参数之外的另一条思路。

其他建议

  • 对输入数据进行归一化或缩放可能会有所帮助,因为当前数据的绝对值很大。

总结

要使模型能够成功找到合适的w值,不仅需要对学习率和优化器进行微调,初始化权重,选择模型结构等等都会有一定的影响,针对性的修改才能获得最终想要的效果。