关于 PLS(偏最小二乘回归):我在网上没有找到多元高次回归的严格证明,但经过测试,我验证了代码的可用性。在某些情况下,加入高次项(例如 x*x)后拟合的 R² 更高,但很多时候并非如此;这里只是提供一种优化的可能,只需修改 polynomial 参数即可。PLS 部分的代码全部改写自 MATLAB 代码,并自行加入了画图、高次项优化和 R² 评价函数。该代码基本已是完整代码,只有 def Polynomial(self): 函数仍有缺陷,如有更好的想法,欢迎一起改进。
如需转载,请注明出处并与我联系,谢谢。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn import preprocessing
from numpy.matlib import repmat
import csv
from math import ceil


class Linear:
    """Partial least squares (PLS) regression with optional product terms.

    Constructing an instance runs the whole pipeline: load the table, expand
    it with product terms when ``polynomial != 1``, fit the PLS model, write
    the coefficients to ``result.csv`` and draw/save the diagnostic figure.

    Args:
        dependent: Number of dependent (response) columns. They must be the
            last columns of the table.
        document: Path to a ``.csv``, ``.xlsx`` or ``.xls`` file whose first
            row holds the column names.
        polynomial: Highest degree of the generated product terms; ``1``
            leaves the table untouched.
    """

    def __init__(self, dependent, document, polynomial=1):
        self.dependent = dependent
        self.document = document
        self.polynomial = polynomial
        self.openfile()
        # Number of independent columns before any product terms are added.
        self.n = len(self.df.columns) - self.dependent
        self.Polynomial()
        x0, y0, num, xishu, ch0, xish, sol = self.find()
        self.save(sol=sol)
        self.PLOT(ch0=ch0, num=num, x0=x0, y0=y0, xishu=xishu, xish=xish)

    def openfile(self):
        """Load ``self.document`` into ``self.df`` based on its extension.

        Raises:
            SystemExit: If the extension is not csv, xlsx or xls.
        """
        # Lower-case so that e.g. "DATA.CSV" is accepted as well.
        file_type = self.document.split(".")[-1].lower()
        if file_type == "csv":
            self.df = pd.read_csv(self.document, encoding='GBK')
        elif file_type in ("xlsx", "xls"):
            self.df = pd.read_excel(self.document)
        else:
            # ``exit`` is only injected by the ``site`` module; raising
            # SystemExit has the same effect and always works.
            raise SystemExit("Unknown file type")

    def Polynomial(self):
        """Append product terms up to degree ``self.polynomial`` to ``self.df``.

        Does nothing when ``self.polynomial == 1``. Otherwise the dependent
        columns are set aside, every column produced by the previous pass is
        multiplied by every original independent column, the dependent
        columns are re-appended, and the table is written to ``changed.csv``.

        NOTE: known limitations, kept as-is so column names stay stable:
        both orderings of a product are generated (e.g. ``x12`` and ``x21``
        hold the same values), and names are built by concatenating digits,
        so they can collide once there are ten or more independent columns.
        """
        if self.polynomial != 1:
            # Copy so the saved responses do not depend on the frame that is
            # modified in place just below.
            temp = self.df.iloc[:, -self.dependent:].copy()
            self.df.drop(self.df.columns[-self.dependent:], axis=1, inplace=True)
            count = self.n
            count_begin = 0
            for i in range(1, self.polynomial):
                # Columns [count_begin, count_end) are the terms produced by
                # the previous pass (the raw columns on the first pass).
                count_end = count
                for k in range(self.n):
                    for j in range(count_begin, count_end):
                        if i == 1:
                            name = "x" + str(k + 1) + str(j + 1)
                        else:
                            name = "x" + str(k + 1) + self.df.columns[j][1:]
                        count += 1
                        self.df[name] = self.df.iloc[:, j].mul(self.df.iloc[:, k])
                count_begin = count_end
            for i in range(len(temp.columns)):
                self.df[temp.columns[i]] = temp.iloc[:, i]
            self.df.to_csv("changed.csv", encoding='GBK')

    def find(self):
        """Fit the PLS model on ``self.df``.

        The correlation matrix of the table is written to
        ``相关系数矩阵.csv``. Components are extracted one at a time; after
        each one the leave-one-out statistic ``Q_h2`` is computed and
        extraction stops at the first component whose ``Q_h2`` is below
        0.0975.

        Returns:
            Tuple ``(x0, y0, num, xishu, ch0, xish, sol)`` where ``x0``/``y0``
            are the raw independent/dependent data, ``num`` is the number of
            rows, ``xishu`` holds the coefficients for standardised data,
            ``ch0`` the intercepts and ``xish`` the coefficients for raw
            data, and ``sol`` stacks ``ch0`` on top of ``xish``.

        Raises:
            SystemExit: If no component ever satisfies the stopping rule.
        """
        df = self.df
        df_matrix = np.array(df)
        mu = np.mean(df_matrix, axis=0)
        sig = np.std(df_matrix, axis=0)
        rr = df.corr()
        rr.to_csv("相关系数矩阵.csv", encoding='GBK')
        data = preprocessing.scale(df_matrix)

        m = self.dependent
        n = len(df.columns) - m
        self.n = n
        x0 = df_matrix[:, :n]
        y0 = df_matrix[:, n:]
        e0 = data[:, :n]
        f0 = data[:, n:]
        num = len(df.iloc[:, 0])

        chg = np.identity(n)
        w = np.zeros([n, n])
        w_star = np.zeros([n, n])
        t = np.zeros([num, n])
        ss = []
        Q_h2 = []
        press_i = [0 for _ in range(num)]
        press = [0 for _ in range(n)]
        r = None

        for i in range(n):
            # (E'F)(E'F)' is symmetric positive semi-definite, so use eigh:
            # it always returns real eigenpairs (eig may return complex ones
            # for the repeated zero eigenvalues) in ascending order, hence
            # the last column is the dominant direction.
            cross = e0.T @ f0
            matrix = cross @ cross.T
            _, vec = np.linalg.eigh(matrix)
            w[:, i] = vec[:, -1]
            w_star[:, i] = chg @ w[:, i]
            t[:, i] = e0 @ w[:, i]

            # Row vector of loadings used to deflate e0.
            alpha = (e0.T @ t[:, i] / (t[:, i] @ t[:, i])).reshape(1, n)
            chg = chg @ (np.identity(n) - w[:, i:i + 1] @ alpha)
            e0 = e0 - t[:, i:i + 1] @ alpha

            # Regress f0 on the scores found so far; the trailing row of the
            # solution is the intercept and is discarded.
            beta = np.linalg.pinv(np.c_[t[:, :i + 1], np.ones(num)]) @ f0
            beta = np.delete(beta, (-1), axis=0)
            cancha = f0 - t[:, :i + 1] @ beta
            ss.append((cancha ** 2).sum())

            # Leave-one-out: refit without row j and predict row j.
            for j in range(num):
                t1 = t[:, :i + 1]
                she_t = t1[j:j + 1, :]
                she_f = f0[j:j + 1, :]
                t1 = np.delete(t1, j, axis=0)
                f1 = np.delete(f0, j, axis=0)
                beta1 = np.linalg.pinv(np.c_[t1, np.ones(num - 1)]) @ f1
                beta1 = np.delete(beta1, (-1), axis=0)
                cancha = she_f - she_t @ beta1
                press_i[j] = (cancha ** 2).sum()
            press[i] = np.array(press_i).sum()

            if i > 0:
                Q_h2.append(1 - press[i] / ss[i - 1])
            else:
                Q_h2.append(1)
            if Q_h2[i] < 0.0975:
                print("Q_h2 = {}".format(Q_h2[-1]))
                r = i
                break

        if r is None:
            # Same message and exit status as before, without relying on the
            # ``exit`` helper that only the ``site`` module provides.
            raise SystemExit("Can't find")

        beta_z = np.linalg.pinv(np.c_[t[:, :r + 1], np.ones(num)]) @ f0
        beta_z = np.delete(beta_z, (-1), axis=0)
        xishu = w_star[:, :r + 1] @ beta_z

        mu_x = mu[:n]
        mu_y = mu[n:]
        sig_x = sig[:n]
        sig_y = sig[n:]

        # Undo the standardisation. Use a 1-D dot product so the result is a
        # true scalar; float() on a one-element array is deprecated in NumPy.
        ch0 = []
        for i in range(m):
            ch0.append(float(mu_y[i] - (mu_x / sig_x * sig_y[i]) @ xishu[:, i]))
        xish = np.zeros([n, m])
        for i in range(m):
            xish[:, i] = xishu[:, i] / sig_x * sig_y[i]
        sol = np.r_[np.array([ch0]), xish]
        return x0, y0, num, xishu, ch0, xish, sol

    def PLOT(self, ch0, num, x0, y0, xishu, xish):
        """Print R^2 per response and draw the diagnostic figure.

        For every dependent variable the left panel plots fitted against
        observed values together with the identity line, and the right panel
        shows a bar chart of the standardised coefficients. The figure is
        saved as ``verify.jpg`` and then shown.
        """
        # Use the SimHei font so that Chinese labels render correctly.
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.rcParams['axes.unicode_minus'] = False
        ch0 = repmat(ch0, num, 1)
        y_hat = ch0 + x0 @ xish
        y1max = y_hat.max(axis=0)
        y2max = y0.max(axis=0)
        ymax = np.r_[np.array([y1max]), np.array([y2max])].max(axis=0)
        for i in range(self.dependent):
            print("y{}: R^2 score = {}".format(
                i + 1, R2_func(y_hat[:, i], self.df.iloc[:, -self.dependent + i])))
            plt.subplot(self.dependent, 2, i * 2 + 1)
            x = [-1, ceil(ymax[i])]
            plt.plot(x, x, '-')
            plt.plot(y_hat[:, i], y0[:, i], '*')
            plt.title("y{}".format(i + 1))
            plt.subplot(self.dependent, 2, i * 2 + 2)
            x = np.arange(self.n)
            plt.bar(x, height=xishu[:, i].reshape([1, self.n], order='F')[0], width=0.5)
            plt.plot([0, self.n], [0, 0], "-")
            plt.title("y{}".format(i + 1))
        plt.tight_layout()
        plt.savefig("verify.jpg")
        plt.show()

    def save(self, sol):
        """Print the labelled coefficient table and write it to ``result.csv``.

        The header row is ``y1 .. ym``; the first column holds ``x0`` (the
        intercept) followed by the names of the independent columns.
        """
        sol = np.r_[[["y{}".format(i + 1) for i in range(self.dependent)]], sol]
        sol = np.c_[["dependent", "x0"] + list(self.df.columns[:-self.dependent]), sol]
        print(sol)
        with open("result.csv", "w", newline="") as file:
            writer = csv.writer(file)
            writer.writerows(sol)


def R2_func(y_test, y):
    """Return the coefficient of determination of ``y_test`` against ``y``."""
    return 1 - ((y_test - y) ** 2).sum() / ((y.mean() - y) ** 2).sum()


if __name__ == '__main__':
    # Arguments: number of dependent variables, file path, polynomial degree.
    # File format: the first row holds column names (not data); no column may
    # be all zeros - remove such columns beforehand.
    # File format: no index column; the independent variables come first and
    # the dependent variables occupy the last columns.
    # Supported file types: .csv, .xlsx, .xls
    Linear(1, "PLS2.csv", 1)