Logistic regression with L1 norm
Prepare data
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the iris dataset and keep only feature columns 2 and 3 as inputs
iris = datasets.load_iris()
X, y = iris.data[:, [2, 3]], iris.target

# Hold out 30% of the samples for testing, stratified on the class labels
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=1, stratify=y
)

# Fit the scaler on the training split only, then apply it to both splits
sc = StandardScaler().fit(X_train)
X_train_std, X_test_std = sc.transform(X_train), sc.transform(X_test)
L1 regularization
Decrease C to make the regularization effect stronger, or increase C to make it weaker (C is the inverse of the regularization strength).
from sklearn.linear_model import LogisticRegression

# L1-penalised logistic regression, fitted one-vs-rest
lr = LogisticRegression(
    penalty='l1',
    C=1.0,
    solver='liblinear',
    multi_class='ovr',
)
lr.fit(X_train_std, y_train)

# Score the fitted model on both splits, then report the results
train_accuracy = lr.score(X_train_std, y_train)
test_accuracy = lr.score(X_test_std, y_test)
print('Training accuracy:', train_accuracy)
print('Test accuracy:', test_accuracy)
Training accuracy: 0.9428571428571428
Test accuracy: 0.9555555555555556
# Bias (intercept) terms of the fitted model, one per one-vs-rest classifier
lr.intercept_
array([-2.30657198, -0.67737202, -3.50790746])
# Weight coefficients of the fitted model: one row per class, one column per
# input feature. An entry of exactly 0 is a weight the L1 penalty removed.
lr.coef_
array([[-4.79455056, 0. ],
[ 1.13405861, -0.79376053],
[ 2.84988176, 3.05759972]])
Impact of regularization strength C
import matplotlib.pyplot as plt
# Plot how the weights of the class-index-1 classifier change as the inverse
# regularization strength C sweeps a log-spaced grid from 1e-4 to 1e5.
fig = plt.figure()
ax = plt.subplot(111)

colors = ['blue', 'green', 'red', 'cyan',
          'magenta', 'yellow', 'black',
          'pink', 'lightgreen', 'lightblue',
          'gray', 'indigo', 'orange']

# One legend entry per weight column; X holds iris feature columns 2 and 3.
feature_names = ['petal length', 'petal width']

weights, params = [], []
for c in np.arange(-4., 6.):
    # Refit from scratch for each C so every point on the curve is independent.
    lr = LogisticRegression(penalty='l1', C=10.**c, solver='liblinear',
                            multi_class='ovr', random_state=0)
    lr.fit(X_train_std, y_train)
    weights.append(lr.coef_[1])
    params.append(10**c)

# Shape (number of C values, number of features): one column per curve.
weights = np.array(weights)

# Give each curve its own scalar label. Passing the whole list of names to a
# single line either raises in matplotlib or prints a stringified list.
for column, name, color in zip(range(weights.shape[1]), feature_names, colors):
    plt.plot(params, weights[:, column],
             label=name,
             color=color)

plt.axhline(0, color='black', linestyle='--', linewidth=3)
plt.xlim([10**(-5), 10**5])
plt.ylabel('weight coefficient')
plt.xlabel('C')
plt.xscale('log')
# A single legend placed outside the axes on the right; the earlier
# plt.legend(loc='upper left') call was overwritten by this one anyway.
ax.legend(loc='upper center',
          bbox_to_anchor=(1.38, 1.03),
          ncol=1, fancybox=True)
#plt.savefig('images/04_07.png', dpi=300,
#            bbox_inches='tight', pad_inches=0.2)
plt.show()