返回

从零开始轻松学深度学习:用numpy实现神经网络训练

闲谈

  1. 深度学习基础知识

1.1 人工智能、机器学习、深度学习三者关系

人工智能(AI)是计算机科学的一个分支,其目标是开发能够像人类一样思考和行动的机器。机器学习(ML)是人工智能的一个子领域,它允许计算机在没有明确编程的情况下学习和改进。深度学习(DL)是机器学习的一个子领域,它使用人工神经网络来学习和改进。

1.2 神经元、突触、权重、偏置、激活函数

人工神经网络是由神经元组成的,每个神经元都模拟生物神经元的行为。神经元具有突触,突触是神经元之间的连接。每个突触都有一个权重,权重决定了突触的强度。神经元还具有偏置,偏置是一个常数,它决定了神经元的激活阈值。当神经元的输入信号超过阈值时,神经元就会被激活,并产生一个输出信号。

1.3 反向传播、梯度下降、优化器

反向传播是一种用于训练人工神经网络的算法。它通过计算神经网络的损失函数的梯度来更新神经网络的权重和偏置。梯度下降是一种优化算法,它使用梯度来找到损失函数的最小值。优化器是实现梯度下降算法的软件工具。

1.4 损失函数、准确率

损失函数是衡量神经网络输出与真实值之间差异的函数。准确率是衡量神经网络预测正确率的指标。

1.5 深度学习框架

深度学习框架是实现深度学习算法的软件平台。常用的深度学习框架包括TensorFlow、PyTorch和Keras。

2. 用Numpy实现神经网络训练

2.1 导入库

import numpy as np

2.2 创建神经网络

class NeuralNetwork:

    def __init__(self, input_size, hidden_size, output_size):
        # 初始化权重和偏置
        self.W1 = np.random.randn(input_size, hidden_size)
        self.b1 = np.zeros((hidden_size,))
        self.W2 = np.random.randn(hidden_size, output_size)
        self.b2 = np.zeros((output_size,))

    def forward(self, X):
        # 前向传播
        Z1 = np.dot(X, self.W1) + self.b1
        A1 = np.tanh(Z1)
        Z2 = np.dot(A1, self.W2) + self.b2
        A2 = np.tanh(Z2)
        return A2

    def backward(self, X, Y, A2):
        # 反向传播
        dZ2 = (A2 - Y) * (1 - A2**2)
        dW2 = np.dot(A1.T, dZ2)
        db2 = np.sum(dZ2, axis=0)
        dZ1 = np.dot(dZ2, self.W2.T) * (1 - A1**2)
        dW1 = np.dot(X.T, dZ1)
        db1 = np.sum(dZ1, axis=0)

        # 更新权重和偏置
        self.W1 -= 0.01 * dW1
        self.b1 -= 0.01 * db1
        self.W2 -= 0.01 * dW2
        self.b2 -= 0.01 * db2

    def train(self, X, Y, epochs=1000):
        # 训练神经网络
        for epoch in range(epochs):
            A2 = self.forward(X)
            self.backward(X, Y, A2)

    def predict(self, X):
        # 预测
        A2 = self.forward(X)
        return A2

2.3 使用神经网络

# 创建神经网络
neural_network = NeuralNetwork(2, 10, 1)

# 训练神经网络
neural_network.train(X, Y, epochs=1000)

# 预测
Y_pred = neural_network.predict(X)

# 评估准确率
accuracy = np.mean(np.round(Y_pred) == Y)
print("准确率:", accuracy)