前言

逻辑回归 = 线性回归 + 激活函数

激活函数的功能: 1. 可以满足非线性因子 2. 将值映射到一个小范围内,减少梯度爆炸的可能

逻辑回归的问题: 1. 回归问题(解答题,根据历史情况和方程计算划分) 2. 分类问题(分类问题,根据已有进行划分)

sigmod 激活函数多用于二分类问题(公式: \(sigmoid(P)=\frac {1} {(1+e^{-P})}\)
softmax 激活函数多用于多分类问题

计算步骤

  1. 先求出计算出来的值:\[P=X@k\]
  2. 利用 sigmoid 激活函数映射出值域为 0~1 的 pre:\[pre=sig(P)=\frac {1} {(1+e^{-P})}\]
  3. 利用 pre 和 Label 分类标签计算 Loss:\[Loss=Label·log(pre)+(1-Label)·log(1-pre)\] 因为 pre 一直在 0~1 之间,所以 \(log (1-pre)\)\(log (pre)\) 是负数,所以这里 loss 是负数,但是为了更加方便计算,常常将 Loss 进行取反:\[Loss=-Label·log(pre)+(1-Label)·log(1-pre)\]
  4. 利用 Loss 对 k 进行影响:\[G=\frac{\partial Loss}{\partial P}=\frac{\partial Loss}{\partial pre}·\frac{\partial pre}{\partial P}=pre-Label\]\[\frac{\partial Loss}{\partial k}=X^T@G\] \[k=k-lr*\frac{\partial Loss}{\partial k}\]

注意

如果需要加上偏置值 b,需要进行更新:\[\frac{\partial Loss}{\partial b}=sum(G)\] \[b=b-lr*\frac{\partial Loss}{\partial b}\]

复现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np  

def sigmoid(x):
return 1/(1+np.exp(-x)) # 利用np对numpy的每个数都单独进行这样的操作


if __name__ == "__main__":

# [毛发长,腿长]
dogs = np.array([[8.9,12],[9,11],[10,13],[9.9,11.2],[12.2,10.1],[9.8,13],[8.8,11.2]],dtype = np.float32) # 0
cats = np.array([[3,4],[5,6],[3.5,5.5],[4.5,5.1],[3.4,4.1],[4.1,5.2],[4.4,4.4]],dtype = np.float32) # 1
labels = np.array([0]*7+[1]*7, dtype=np.int32).reshape(-1,1)
# print(label)

X = np.vstack((dogs,cats)) # np.vstack返回竖直堆叠后的数组
# print(X)

# np.random.normal能指定生成的数据的均值和方差
k = np.random.normal(0,1,size=(2,1)) # 这里规定均值为0,方差为1,2行1列
b = 0 # 偏置值
epoch = 1000
lr = 0.05

for e in range(epoch):
p = X @ k + b # 计算结果

pre = sigmoid(p) # pre的每个值都在0~1
loss = -np.mean(labels * np.log(pre) + (1-labels) * np.log((1-pre))) # loss是一个标量
G = pre - labels

delta_k = X.T @ G
delta_b = np.sum(G)

k = k - lr * delta_k
b = b - lr * delta_b

print(loss)

while True:
f1 = float(input("请输入毛发长:"))
f2 = float(input("请输入腿长:"))

text_x = np.array([f1,f2]).reshape(1,2)
p = sigmoid(text_x @ k + b)
if p>0.5:
print("类别:猫")
else:
print("类别:狗")

相关链接

bilibili
GitHub