36 lines
1.0 KiB
Python
36 lines
1.0 KiB
Python
from sklearn.model_selection import learning_curve
|
|
|
|
# It is good to randomize the data before drawing learning curves:
# sklearn's learning_curve builds its training subsets as prefixes of each
# training fold, so rows that arrive in some order (e.g. grouped by class)
# would otherwise give unrepresentative subsets.
def randomize(X, Y):
    """Return copies of ``X`` and ``Y`` shuffled by the same random row order.

    Parameters
    ----------
    X : numpy.ndarray
        Samples along the first axis (1-D or N-D).
    Y : numpy.ndarray
        Targets, one per sample of ``X``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(X_shuffled, Y_shuffled)``; row ``i`` of both outputs comes from the
        same original row. Uses NumPy's global random state, so call
        ``np.random.seed`` beforehand for reproducible results.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` do not have the same number of rows.
    """
    n_samples = Y.shape[0]
    if X.shape[0] != n_samples:
        # Without this check a longer X would silently lose rows and a
        # shorter X would fail later with an unhelpful IndexError.
        raise ValueError(
            f"X and Y must have the same number of rows, "
            f"got {X.shape[0]} and {n_samples}"
        )

    permutation = np.random.permutation(n_samples)
    # Index only the first axis so 1-D as well as N-D arrays are supported;
    # fancy indexing returns copies, leaving the inputs untouched.
    return X[permutation], Y[permutation]
|
|
|
|
# Build shuffled copies of the dataset once at module level.
# NOTE(review): X and y are not defined in this part of the file; presumably
# they are loaded earlier (e.g. in a previous notebook cell) — confirm.
X2, y2 = randomize(X, y)
|
|
|
|
def draw_learning_curves(X, y, estimator, num_trainings):
    """Plot training and cross-validation learning curves for ``estimator``.

    The data is first shuffled with ``randomize`` so that the training subsets
    taken by ``learning_curve`` are not biased by the original row order.

    Parameters
    ----------
    X : numpy.ndarray
        Feature matrix, one sample per row (must be indexable by a row
        permutation, as ``randomize`` requires).
    y : numpy.ndarray
        Target values, one per row of ``X``.
    estimator : object
        A scikit-learn style estimator passed straight to ``learning_curve``.
    num_trainings : int
        Number of training-set sizes, evenly spaced from 10% to 100% of the
        available training data.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # Shuffle the caller's own data; reading module-level globals here would
    # silently ignore the X and y actually passed in.
    X_shuffled, y_shuffled = randomize(X, y)

    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X_shuffled, y_shuffled, cv=None, n_jobs=1,
        train_sizes=np.linspace(.1, 1.0, num_trainings))

    # Aggregate over CV folds (axis 1); axis 0 indexes the training sizes.
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    plt.grid()

    plt.title("Learning Curves")
    plt.xlabel("Training examples")
    plt.ylabel("Score")

    # Shade +/- one standard deviation across folds around each mean curve.
    plt.fill_between(train_sizes,
                     train_scores_mean - train_scores_std,
                     train_scores_mean + train_scores_std,
                     alpha=0.1, color="g")
    plt.fill_between(train_sizes,
                     test_scores_mean - test_scores_std,
                     test_scores_mean + test_scores_std,
                     alpha=0.1, color="y")

    # Plot against the actual number of training examples so the x-axis
    # matches its "Training examples" label (not just the point index).
    plt.plot(train_sizes, train_scores_mean, 'o-', color="g",
             label="Training score")
    plt.plot(train_sizes, test_scores_mean, 'o-', color="y",
             label="Cross-validation score")

    plt.legend(loc="best")

    plt.show()
|