← 返回首页
SKLearn基础教程(六)
发表时间:2023-06-07 00:24:08
线性回归

1.什么是线性回归

其实回归算法是相对分类算法而言的,与我们想要预测的目标变量y的值类型有关。如果目标变量y是分类型变量,如预测用户的性别(男、女),预测月季花的颜色(红、白、黄……),预测是否患有肺癌(是、否),那我们就需要用分类算法去拟合训练数据并做出预测;如果y是连续型变量,如预测用户的收入(4千,2万,10万……),预测员工的通勤距离(500m,1km,2万里……),预测患肺癌的概率(1%,50%,99%……),我们则需要用回归模型。

线性回归过程主要解决的是如何通过样本获取最佳的拟合线,最常用的方法是最小二乘法。在古代,“平方”的称谓为“二乘”,故得最小二乘法。

2.数据拟合法和插值法

一种数学优化技术,通过最小化残差的平方和寻找数据的最佳函数匹配。在数理统计中,残差是指实际观察值与估计值之间的差。力求总的拟合误差(即总残差)达到最小。

3.最小二乘法计算过程

最小二乘法也被称作最小平方法,最常用的是普通最小二乘法(Ordinary Least Square),它是一种数学中的优化方法,试图找到一个或一组估计值,使得实际值与估计值的尽可能相似,距离最小,目的是通过已有的数据来预测未知数据。一般通过一条多元一次的直线方程,在二维坐标中即二元一次方程,例如在二维坐标中,有非常多的点分散在其中,试图绘制一条直线,使得这些分散的点到直线上的距离最小。这里的距离最小并非点到直线的垂直距离最短,而是点到直接的y轴距离最短,即通过该点并与y轴平行的直线,点到该y轴平行线与直线交点的距离最短,如下图所示。

最小二乘法的核心思想是通过最下化误差的平方和,试图找到最可能的函数方程 。假设在二维坐标系中存在五个数据点(10,20)、(11,23)、(12,25)、(13,27)、(14,26),希望找出一条该五个点距离最短的直线,根据二元一次方程:y=ax+b

因此,将五个点分别带入该二元方程得到如下:

20=10a+b
23=11a+b
25=12a+b
27=13a+b
26=14a+b

由于最小二乘法是尽可能使得等号两边的方差值最小,因此 :

因此求最小值即可通过对S(a,b)求偏导数获得,并使得一阶倒数的值为0,则:

即得到关于求解未知变量a、b的二元一次方程:

通过计算上述二元一次方法即得到a=0.0243,b=24.1708。因此,在上述五个点中,通过最小二乘法得到直线方程:y=0.0243x + 24.1708是使得五个点到该直线距离最小的直线。

4.案例

1).预测学生身高体重为例

# 拟合曲线
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
from scipy.optimize import leastsq

# 样本数据
# 身高数据
Xi = np.array([162, 165, 159, 173, 157, 175, 161, 164, 172, 158])
# 体重数据
Yi = np.array([48, 64, 53, 66, 52, 68, 50, 52, 64, 49])
# 需要拟合的函数func()指定函数的形状
def func(p, x):
    k, b = p
    return k*x + b


# 定义偏差函数,x,y为数组中对应Xi,Yi的值
def error(p, x, y):
    return func(p, x) - y

# 设置k,b的初始值,可以任意设定,经过实验,发现p0的值会影响cost的值:Para[1]
p0 = [1, 20]

# 把error函数中除了p0以外的参数打包到args中,leastsq()为最小二乘法函数
Para = leastsq(error, p0, args=(Xi, Yi))

print(Para)
# 读取结果
k, b = Para[0]
print('k=', k, 'b=', b)

# 画样本点
plt.figure(figsize=(8, 6))
plt.scatter(Xi, Yi, color='red', label='Sample data', linewidth=2)

# 画拟合直线
x = np.linspace(150, 180, 80)
y = k * x + b

# 绘制拟合曲线
plt.plot(x, y, color='blue', label='Fitting Curve', linewidth=2)
plt.legend()  # 绘制图例

plt.xlabel('Height:cm', fontproperties='simHei', fontsize=12)
plt.ylabel('Weight:Kg', fontproperties='simHei', fontsize=12)

plt.show()

2).计算残差

# 计算残差
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
from scipy.optimize import leastsq
from statsmodels.graphics.api import qqplot

# 样本数据
# 身高数据
Xi = np.array([162, 165, 159, 173, 157, 175, 161, 164, 172, 158])
# 体重数据
Yi = np.array([48, 64, 53, 66, 52, 68, 50, 52, 64, 49])

# 定义变量
xy_res=[]
# 定义计算残差函数
def residual(x,y):
    res = y - (0.4211697*x-8.2883026)               # 计算残差
    return res                                      # 返回残差

# 循环读取残差
for d in range(0,len(Xi)):
    res = residual(Xi[d], Yi[d])
    xy_res.append(res)

print(xy_res)
# 计算残差平方和,和越小表明拟合的情况越好
xy_res_pingfangsum = np.dot(xy_res,xy_res)
print(xy_res_pingfangsum)

# 如果数据拟合模型效果好,残差应该遵从正态分布(0,d*d),d表示残差

# 画样本点
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)           # 添加一个子图
fig = qqplot(np.array(xy_res),line='q',ax=ax)  # 设置参数

plt.show()