COMP7404 Machine Learning: Learning Curve & Validation Curve
Published: 2019-04-30


Learning Curve 

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, learning_curve
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Wisconsin Diagnostic Breast Cancer data: column 0 is the ID, column 1 the diagnosis (M/B), columns 2-31 the 30 features
df = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data', header=None)
X = df.loc[:, 2:].values
y = df.loc[:, 1].values
le = LabelEncoder()
y = le.fit_transform(y)  # 'B' -> 0, 'M' -> 1
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, stratify=y, random_state=1)

# Standardize the features, then fit an L2-regularized logistic regression
pipe_lr = make_pipeline(StandardScaler(), LogisticRegression(penalty='l2', random_state=1, solver='liblinear'))

# 10 training-set sizes from 10% to 100% of the training folds, evaluated with 10-fold cross-validation
train_sizes, train_scores, test_scores = learning_curve(estimator=pipe_lr, X=X_train, y=y_train,
                                                        train_sizes=np.linspace(0.1, 1.0, 10), cv=10, n_jobs=1)
# Each score array has one row per training-set size and one column per fold; average over the folds
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)

# Mean accuracy with a +/- one standard deviation band
plt.plot(train_sizes, train_mean, color='blue', marker='o', markersize=5, label='training accuracy')
plt.fill_between(train_sizes, train_mean + train_std, train_mean - train_std, alpha=0.15, color='blue')
plt.plot(train_sizes, test_mean, color='green', linestyle='--', marker='s', markersize=5, label='validation accuracy')
plt.fill_between(train_sizes, test_mean + test_std, test_mean - test_std, alpha=0.15, color='green')
plt.grid()
plt.xlabel('Number of training samples')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.ylim([0.8, 1.03])
plt.tight_layout()
plt.show()
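To see what `learning_curve` actually returns, here is a minimal sketch that runs offline. It assumes `sklearn.datasets.load_breast_cancer`, which bundles a copy of the same Wisconsin Diagnostic Breast Cancer data (its labels are encoded the other way round, 0 = malignant, which does not affect accuracy), and it uses `cv=5` instead of 10 only to keep it quick. The fractions in `train_sizes` come back as absolute sample counts, and each score array has one row per size and one column per fold, which is why the plotting code above averages with `axis=1`.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import learning_curve, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Assumption: the bundled scikit-learn copy of WDBC (569 samples) instead of the UCI download
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, stratify=y, random_state=1)

pipe_lr = make_pipeline(
    StandardScaler(),
    LogisticRegression(penalty='l2', random_state=1, solver='liblinear'))

# Fractions 0.1..1.0 are scaled by the size of a CV training fold and truncated to ints
train_sizes, train_scores, test_scores = learning_curve(
    estimator=pipe_lr, X=X_train, y=y_train,
    train_sizes=np.linspace(0.1, 1.0, 10), cv=5, n_jobs=1)

print(train_sizes)         # absolute number of training samples at each point
print(train_scores.shape)  # (number of sizes, number of folds)

# Mean training accuracy minus mean validation accuracy at each size:
# a large gap that persists suggests high variance; two low curves suggest high bias
gap = train_scores.mean(axis=1) - test_scores.mean(axis=1)
print(np.round(gap, 3))
```

With 455 training samples and 5 folds, each training fold holds 364 samples, so the last entry of `train_sizes` is 364 rather than 455.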

 

Validation Curve

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, validation_curve
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Same data preparation as for the learning curve
df = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data', header=None)
X = df.loc[:, 2:].values
y = df.loc[:, 1].values
le = LabelEncoder()
y = le.fit_transform(y)  # 'B' -> 0, 'M' -> 1
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, stratify=y, random_state=1)
print(len(X_train))  # 455 of the 569 samples go to the training split

pipe_lr = make_pipeline(StandardScaler(), LogisticRegression(penalty='l2', random_state=1, solver='liblinear'))

# Inverse regularization strength C: smaller values mean stronger regularization
param_range = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]
# 'logisticregression__C' addresses parameter C of the pipeline step named 'logisticregression'
train_scores, test_scores = validation_curve(estimator=pipe_lr, X=X_train, y=y_train, param_name='logisticregression__C',
                                             param_range=param_range, cv=10)
# One row per C value, one column per fold; average over the folds
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)

plt.plot(param_range, train_mean, color='blue', marker='o', markersize=5, label='training accuracy')
plt.fill_between(param_range, train_mean + train_std, train_mean - train_std, alpha=0.15, color='blue')
plt.plot(param_range, test_mean, color='green', linestyle='--', marker='s', markersize=5, label='validation accuracy')
plt.fill_between(param_range, test_mean + test_std, test_mean - test_std, alpha=0.15, color='green')
plt.grid()
plt.xscale('log')  # the C values are spaced logarithmically
plt.legend(loc='lower right')
plt.xlabel('Parameter C')
plt.ylabel('Accuracy')
plt.ylim([0.8, 1.0])
plt.tight_layout()
plt.show()
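The validation curve is usually read to pick a value of C. A minimal sketch of that step, under the same assumptions as before (`load_breast_cancer` as an offline stand-in for the UCI file, `cv=5` for speed): take the C with the highest mean validation accuracy, refit the pipeline on the whole training split, and score the `X_test` split, which the code above creates but never uses, exactly once. Choosing by `argmax` of the mean is one simple rule, not the only one; a smaller C within one standard deviation of the best is also common.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, validation_curve
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Assumption: bundled scikit-learn copy of WDBC instead of the UCI download
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, stratify=y, random_state=1)

pipe_lr = make_pipeline(
    StandardScaler(),
    LogisticRegression(penalty='l2', random_state=1, solver='liblinear'))
param_range = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]

# make_pipeline names each step after its lowercased class name,
# hence 'logisticregression__C'
train_scores, test_scores = validation_curve(
    estimator=pipe_lr, X=X_train, y=y_train,
    param_name='logisticregression__C', param_range=param_range, cv=5)

# Pick the C with the best mean validation accuracy across folds
test_mean = test_scores.mean(axis=1)
best_C = param_range[int(np.argmax(test_mean))]

# Refit on the full training split with that C, then evaluate the held-out split once
pipe_lr.set_params(logisticregression__C=best_C)
pipe_lr.fit(X_train, y_train)
test_acc = pipe_lr.score(X_test, y_test)
print('best C:', best_C)
print('test accuracy: %.3f' % test_acc)
```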

 

Reprint address: http://xmygf.baihongyu.com/
