{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Linear Model Selection and Regularization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- [Lab 2: Ridge Regression](#Ridge-Regression)\n", "- [Lab 2: The Lasso](#The-Lasso)\n", "- [Lab 3: Principal Components Regression](#Principal-Components-Regression)\n", "- [Lab 3: Partial Least Squares](#Partial-Least-Squares)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# %load ../standard_import.txt\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from sklearn.preprocessing import scale \n", "from sklearn import model_selection\n", "from sklearn.linear_model import LinearRegression, Ridge, RidgeCV, Lasso, LassoCV\n", "from sklearn.decomposition import PCA\n", "from sklearn.cross_decomposition import PLSRegression\n", "from sklearn.model_selection import KFold, cross_val_score\n", "from sklearn.metrics import mean_squared_error, r2_score" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index: 263 entries, -Alan Ashby to -Willie Wilson\n", "Data columns (total 20 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 AtBat 263 non-null int64 \n", " 1 Hits 263 non-null int64 \n", " 2 HmRun 263 non-null int64 \n", " 3 Runs 263 non-null int64 \n", " 4 RBI 263 non-null int64 \n", " 5 Walks 263 non-null int64 \n", " 6 Years 263 non-null int64 \n", " 7 CAtBat 263 non-null int64 \n", " 8 CHits 263 non-null int64 \n", " 9 CHmRun 263 non-null int64 \n", " 10 CRuns 263 non-null int64 \n", " 11 CRBI 263 non-null int64 \n", " 12 CWalks 263 non-null int64 \n", " 13 League 263 non-null object \n", " 14 Division 263 non-null object \n", " 15 PutOuts 263 
non-null int64 \n", " 16 Assists 263 non-null int64 \n", " 17 Errors 263 non-null int64 \n", " 18 Salary 263 non-null float64\n", " 19 NewLeague 263 non-null object \n", "dtypes: float64(1), int64(16), object(3)\n", "memory usage: 43.1+ KB\n" ] } ], "source": [ "# In R, I exported the dataset from package 'ISLR' to a csv file.\n", "df = pd.read_csv('data/Hitters.csv', index_col=0).dropna()\n", "df.index.name = 'Player'\n", "df.info()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AtBatHitsHmRunRunsRBIWalksYearsCAtBatCHitsCHmRunCRunsCRBICWalksLeagueDivisionPutOutsAssistsErrorsSalaryNewLeague
Player
-Alan Ashby31581724383914344983569321414375NW6324310475.0N
-Alvin Davis479130186672763162445763224266263AW8808214480.0A
-Andre Dawson496141206578371156281575225828838354NE200113500.0N
\n", "
" ], "text/plain": [ " AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits \\\n", "Player \n", "-Alan Ashby 315 81 7 24 38 39 14 3449 835 \n", "-Alvin Davis 479 130 18 66 72 76 3 1624 457 \n", "-Andre Dawson 496 141 20 65 78 37 11 5628 1575 \n", "\n", " CHmRun CRuns CRBI CWalks League Division PutOuts Assists \\\n", "Player \n", "-Alan Ashby 69 321 414 375 N W 632 43 \n", "-Alvin Davis 63 224 266 263 A W 880 82 \n", "-Andre Dawson 225 828 838 354 N E 200 11 \n", "\n", " Errors Salary NewLeague \n", "Player \n", "-Alan Ashby 10 475.0 N \n", "-Alvin Davis 14 480.0 A \n", "-Andre Dawson 3 500.0 N " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head(3)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index: 263 entries, -Alan Ashby to -Willie Wilson\n", "Data columns (total 6 columns):\n", " # Column Non-Null Count Dtype\n", "--- ------ -------------- -----\n", " 0 League_A 263 non-null bool \n", " 1 League_N 263 non-null bool \n", " 2 Division_E 263 non-null bool \n", " 3 Division_W 263 non-null bool \n", " 4 NewLeague_A 263 non-null bool \n", " 5 NewLeague_N 263 non-null bool \n", "dtypes: bool(6)\n", "memory usage: 3.6+ KB\n", " League_A League_N Division_E Division_W NewLeague_A \\\n", "Player \n", "-Alan Ashby False True False True False \n", "-Alvin Davis True False False True True \n", "-Andre Dawson False True True False False \n", "\n", " NewLeague_N \n", "Player \n", "-Alan Ashby True \n", "-Alvin Davis False \n", "-Andre Dawson True \n" ] } ], "source": [ "# Convert categorical variable into dummy/indicator variables.\n", "dummies = pd.get_dummies(df[['League', 'Division', 'NewLeague']])\n", "dummies.info()\n", "print(dummies.head(3))\n", "dummies = dummies.astype(float)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index: 
263 entries, -Alan Ashby to -Willie Wilson\n", "Data columns (total 19 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 AtBat 263 non-null float64\n", " 1 Hits 263 non-null float64\n", " 2 HmRun 263 non-null float64\n", " 3 Runs 263 non-null float64\n", " 4 RBI 263 non-null float64\n", " 5 Walks 263 non-null float64\n", " 6 Years 263 non-null float64\n", " 7 CAtBat 263 non-null float64\n", " 8 CHits 263 non-null float64\n", " 9 CHmRun 263 non-null float64\n", " 10 CRuns 263 non-null float64\n", " 11 CRBI 263 non-null float64\n", " 12 CWalks 263 non-null float64\n", " 13 PutOuts 263 non-null float64\n", " 14 Assists 263 non-null float64\n", " 15 Errors 263 non-null float64\n", " 16 League_N 263 non-null float64\n", " 17 Division_W 263 non-null float64\n", " 18 NewLeague_N 263 non-null float64\n", "dtypes: float64(19)\n", "memory usage: 41.1+ KB\n" ] } ], "source": [ "y = df.Salary\n", "\n", "# Drop the response (dependent) variable Salary, and the categorical columns for which we created dummy variables\n", "X_ = df.drop(['Salary', 'League', 'Division', 'NewLeague'], axis=1).astype('float64')\n", "# Define the feature set X.\n", "X = pd.concat([X_, dummies[['League_N', 'Division_W', 'NewLeague_N']]], axis=1)\n", "X.info()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AtBatHitsHmRunRunsRBIWalksYearsCAtBatCHitsCHmRunCRunsCRBICWalksPutOutsAssistsErrorsLeague_NDivision_WNewLeague_N
Player
-Alan Ashby315.081.07.024.038.039.014.03449.0835.069.0321.0414.0375.0632.043.010.01.01.01.0
-Alvin Davis479.0130.018.066.072.076.03.01624.0457.063.0224.0266.0263.0880.082.014.00.01.00.0
-Andre Dawson496.0141.020.065.078.037.011.05628.01575.0225.0828.0838.0354.0200.011.03.01.00.01.0
\n", "
" ], "text/plain": [ " AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits \\\n", "Player \n", "-Alan Ashby 315.0 81.0 7.0 24.0 38.0 39.0 14.0 3449.0 835.0 \n", "-Alvin Davis 479.0 130.0 18.0 66.0 72.0 76.0 3.0 1624.0 457.0 \n", "-Andre Dawson 496.0 141.0 20.0 65.0 78.0 37.0 11.0 5628.0 1575.0 \n", "\n", " CHmRun CRuns CRBI CWalks PutOuts Assists Errors \\\n", "Player \n", "-Alan Ashby 69.0 321.0 414.0 375.0 632.0 43.0 10.0 \n", "-Alvin Davis 63.0 224.0 266.0 263.0 880.0 82.0 14.0 \n", "-Andre Dawson 225.0 828.0 838.0 354.0 200.0 11.0 3.0 \n", "\n", " League_N Division_W NewLeague_N \n", "Player \n", "-Alan Ashby 1.0 1.0 1.0 \n", "-Alvin Davis 0.0 1.0 0.0 \n", "-Andre Dawson 1.0 0.0 1.0 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X.head(3)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AtBatHitsHmRunRunsRBIWalksYearsCAtBatCHitsCHmRunCRunsCRBICWalksLeagueNDivisionWPutOutsAssistsErrorsNewLeagueN
-Darryl Strawberry4751232776937241810471108292343267102261061
-Glenn Wilson584158157084425235863658265316134103312041
-Leon Durham48412720666567730068441164364583771012318071
\n", "
" ], "text/plain": [ " AtBat Hits HmRun Runs RBI Walks Years CAtBat \\\n", "-Darryl Strawberry 475 123 27 76 93 72 4 1810 \n", "-Glenn Wilson 584 158 15 70 84 42 5 2358 \n", "-Leon Durham 484 127 20 66 65 67 7 3006 \n", "\n", " CHits CHmRun CRuns CRBI CWalks LeagueN DivisionW \\\n", "-Darryl Strawberry 471 108 292 343 267 1 0 \n", "-Glenn Wilson 636 58 265 316 134 1 0 \n", "-Leon Durham 844 116 436 458 377 1 0 \n", "\n", " PutOuts Assists Errors NewLeagueN \n", "-Darryl Strawberry 226 10 6 1 \n", "-Glenn Wilson 331 20 4 1 \n", "-Leon Durham 1231 80 7 1 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train = pd.read_csv('data/Hitters_X_train.csv', index_col=0)\n", "y_train = pd.read_csv('data/Hitters_y_train.csv', index_col=0)\n", "X_test = pd.read_csv('data/Hitters_X_test.csv', index_col=0)\n", "y_test = pd.read_csv('data/Hitters_y_test.csv', index_col=0)\n", "X_train.head(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Ridge Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "The __sklearn Ridge()__ function optimizes:\n", " $$ ||X\\beta - y||^2_2 + \\alpha ||\\beta||^2_2 $$\n", "which is equivalent to optimizing\n", " $$ \\frac{1}{N}||X\\beta - y||^2_2 + \\frac{\\alpha}{N} ||\\beta||^2_2 $$" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "alphas = 10**np.linspace(10,-2,100)*0.5\n", "\n", "# Standardize the predictors once; the scaled matrix is the same for every alpha.\n", "X_scaled = scale(X)\n", "\n", "ridge = Ridge()\n", "coefs = []\n", "\n", "for a in alphas:\n", "    ridge.set_params(alpha=a)\n", "    ridge.fit(X_scaled, y)\n", "    coefs.append(ridge.coef_)\n", "\n", "ax = plt.gca()\n", "ax.plot(alphas, coefs)\n", "ax.set_xscale('log')\n", "ax.set_xlim(ax.get_xlim()[::-1])  # reverse axis\n", "plt.axis('tight')\n", "plt.xlabel('alpha')\n", "plt.ylabel('weights')\n", "plt.title('Ridge coefficients as a function of the regularization');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above plot shows that the Ridge coefficients grow in magnitude as alpha decreases, i.e. as the penalty weakens." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "\n", "scaler = StandardScaler().fit(X_train)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "98866.37786033245\n", "0.4868771617300639\n" ] } ], "source": [ "\n", "ridge2 = Ridge(alpha=X_train.shape[1]/10)\n", "ridge2.fit(scaler.transform(X_train), y_train)\n", "pred = ridge2.predict(scaler.transform(X_test))\n", "\n", "print(mean_squared_error(y_test, pred))\n", "print(r2_score(y_test, pred))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AtBat -232.417538\n", "Hits 185.399421\n", "HmRun 61.693947\n", "Runs -82.295866\n", "RBI 47.624703\n", "Walks 156.248223\n", "Years -122.339806\n", "CAtBat -82.785224\n", "CHits 134.265011\n", "CHmRun -19.297365\n", "CRuns 369.449829\n", "CRBI -50.380718\n", "CWalks -103.110248\n", "PutOuts 22.500903\n", "Assists -64.528856\n", "Errors 90.415338\n", "League_N 37.281989\n", "Division_W -13.502786\n", "NewLeague_N 3.273231\n", "dtype: float64" ] }, "execution_count": 11, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "# ridge2 was fit on X_train, so label its coefficients with X_train's columns.\n", "# X names and orders the dummy columns differently, so X.columns would mislabel them.\n", "pd.Series(ridge2.coef_.flatten(), index=X_train.columns)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Alpha = $10^{10}$ \n", "This big penalty shrinks the coefficients to a very large degree and makes the model more biased, resulting in a higher MSE." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "193253.09741651407\n", "-0.002995962802482266\n" ] } ], "source": [ "ridge2.set_params(alpha=10**10)\n", "ridge2.fit(scaler.transform(X_train), y_train)\n", "# Standardize the test set with the training-set statistics held by scaler,\n", "# not with the test set's own mean and standard deviation.\n", "pred = ridge2.predict(scaler.transform(X_test))\n", "\n", "print(mean_squared_error(y_test, pred))\n", "print(r2_score(y_test, pred))  # a negative R^2 means RSS > TSS" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Compute the regularization path using RidgeCV" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RidgeCV(alphas=array([5.00000000e+09, 3.78231664e+09, 2.86118383e+09, 2.16438064e+09,\n",
       "       1.63727458e+09, 1.23853818e+09, 9.36908711e+08, 7.08737081e+08,\n",
       "       5.36133611e+08, 4.05565415e+08, 3.06795364e+08, 2.32079442e+08,\n",
       "       1.75559587e+08, 1.32804389e+08, 1.00461650e+08, 7.59955541e+07,\n",
       "       5.74878498e+07, 4.34874501e+07, 3.28966612e+07, 2.48851178e+07,\n",
       "       1.88246790e+07, 1.42401793e+0...\n",
       "       3.06795364e+00, 2.32079442e+00, 1.75559587e+00, 1.32804389e+00,\n",
       "       1.00461650e+00, 7.59955541e-01, 5.74878498e-01, 4.34874501e-01,\n",
       "       3.28966612e-01, 2.48851178e-01, 1.88246790e-01, 1.42401793e-01,\n",
       "       1.07721735e-01, 8.14875417e-02, 6.16423370e-02, 4.66301673e-02,\n",
       "       3.52740116e-02, 2.66834962e-02, 2.01850863e-02, 1.52692775e-02,\n",
       "       1.15506485e-02, 8.73764200e-03, 6.60970574e-03, 5.00000000e-03]),\n",
       "        scoring='neg_mean_squared_error')
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RidgeCV(alphas=array([5.00000000e+09, 3.78231664e+09, 2.86118383e+09, 2.16438064e+09,\n", " 1.63727458e+09, 1.23853818e+09, 9.36908711e+08, 7.08737081e+08,\n", " 5.36133611e+08, 4.05565415e+08, 3.06795364e+08, 2.32079442e+08,\n", " 1.75559587e+08, 1.32804389e+08, 1.00461650e+08, 7.59955541e+07,\n", " 5.74878498e+07, 4.34874501e+07, 3.28966612e+07, 2.48851178e+07,\n", " 1.88246790e+07, 1.42401793e+0...\n", " 3.06795364e+00, 2.32079442e+00, 1.75559587e+00, 1.32804389e+00,\n", " 1.00461650e+00, 7.59955541e-01, 5.74878498e-01, 4.34874501e-01,\n", " 3.28966612e-01, 2.48851178e-01, 1.88246790e-01, 1.42401793e-01,\n", " 1.07721735e-01, 8.14875417e-02, 6.16423370e-02, 4.66301673e-02,\n", " 3.52740116e-02, 2.66834962e-02, 2.01850863e-02, 1.52692775e-02,\n", " 1.15506485e-02, 8.73764200e-03, 6.60970574e-03, 5.00000000e-03]),\n", " scoring='neg_mean_squared_error')" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Select alpha by cross-validation on the standardized training data.\n", "ridgecv = RidgeCV(alphas=alphas, scoring='neg_mean_squared_error')\n", "ridgecv.fit(scaler.transform(X_train), y_train)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "115.5064850041579" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ridgecv.alpha_" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "97384.92959172589" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Refit with the cross-validated alpha and score on the test set.\n", "# The test set is standardized with the training-set statistics held by scaler.\n", "ridge2.set_params(alpha=ridgecv.alpha_)\n", "ridge2.fit(scaler.transform(X_train), y_train)\n", "mean_squared_error(y_test, ridge2.predict(scaler.transform(X_test)))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AtBat 7.576771\n", "Hits 22.596030\n", "HmRun 18.971990\n", "Runs 20.193945\n", "RBI 21.063875\n", "Walks 55.713281\n", "Years -4.687149\n", "CAtBat 20.496892\n", "CHits 29.230247\n", "CHmRun 
14.293124\n", "CRuns 35.881788\n", "CRBI 20.212172\n", "CWalks 24.419768\n", "PutOuts 16.128910\n", "Assists -44.102264\n", "Errors 54.624503\n", "League_N 5.771464\n", "Division_W -0.293713\n", "NewLeague_N 11.137518\n", "dtype: float64" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# ridge2 was fit on X_train, so label its coefficients with X_train's columns.\n", "# X names and orders the dummy columns differently, so X.columns would mislabel them.\n", "pd.Series(ridge2.coef_.flatten(), index=X_train.columns)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### The Lasso" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Scikit-learn" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Standardize the training predictors once; the result is the same for every alpha.\n", "X_train_scaled = scaler.transform(X_train)\n", "\n", "lasso = Lasso(max_iter=10000)\n", "coefs = []\n", "\n", "for a in alphas:\n", "    lasso.set_params(alpha=a)\n", "    lasso.fit(X_train_scaled, y_train)\n", "    coefs.append(lasso.coef_)\n", "\n", "ax = plt.gca()\n", "ax.plot(alphas, coefs)\n", "ax.set_xscale('log')\n", "ax.set_xlim(ax.get_xlim()[::-1])  # reverse axis\n", "plt.axis('tight')\n", "plt.xlabel('alpha')\n", "plt.ylabel('weights')\n", "plt.title('Lasso coefficients as a function of the regularization');" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
LassoCV(alphas=array([5.00000000e+09, 3.78231664e+09, 2.86118383e+09, 2.16438064e+09,\n",
       "       1.63727458e+09, 1.23853818e+09, 9.36908711e+08, 7.08737081e+08,\n",
       "       5.36133611e+08, 4.05565415e+08, 3.06795364e+08, 2.32079442e+08,\n",
       "       1.75559587e+08, 1.32804389e+08, 1.00461650e+08, 7.59955541e+07,\n",
       "       5.74878498e+07, 4.34874501e+07, 3.28966612e+07, 2.48851178e+07,\n",
       "       1.88246790e+07, 1.42401793e+0...\n",
       "       3.06795364e+00, 2.32079442e+00, 1.75559587e+00, 1.32804389e+00,\n",
       "       1.00461650e+00, 7.59955541e-01, 5.74878498e-01, 4.34874501e-01,\n",
       "       3.28966612e-01, 2.48851178e-01, 1.88246790e-01, 1.42401793e-01,\n",
       "       1.07721735e-01, 8.14875417e-02, 6.16423370e-02, 4.66301673e-02,\n",
       "       3.52740116e-02, 2.66834962e-02, 2.01850863e-02, 1.52692775e-02,\n",
       "       1.15506485e-02, 8.73764200e-03, 6.60970574e-03, 5.00000000e-03]),\n",
       "        cv=10, max_iter=10000)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "LassoCV(alphas=array([5.00000000e+09, 3.78231664e+09, 2.86118383e+09, 2.16438064e+09,\n", " 1.63727458e+09, 1.23853818e+09, 9.36908711e+08, 7.08737081e+08,\n", " 5.36133611e+08, 4.05565415e+08, 3.06795364e+08, 2.32079442e+08,\n", " 1.75559587e+08, 1.32804389e+08, 1.00461650e+08, 7.59955541e+07,\n", " 5.74878498e+07, 4.34874501e+07, 3.28966612e+07, 2.48851178e+07,\n", " 1.88246790e+07, 1.42401793e+0...\n", " 3.06795364e+00, 2.32079442e+00, 1.75559587e+00, 1.32804389e+00,\n", " 1.00461650e+00, 7.59955541e-01, 5.74878498e-01, 4.34874501e-01,\n", " 3.28966612e-01, 2.48851178e-01, 1.88246790e-01, 1.42401793e-01,\n", " 1.07721735e-01, 8.14875417e-02, 6.16423370e-02, 4.66301673e-02,\n", " 3.52740116e-02, 2.66834962e-02, 2.01850863e-02, 1.52692775e-02,\n", " 1.15506485e-02, 8.73764200e-03, 6.60970574e-03, 5.00000000e-03]),\n", " cv=10, max_iter=10000)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Select alpha by 10-fold cross-validation on the standardized training data.\n", "lassocv = LassoCV(alphas=alphas, cv=10, max_iter=10000)\n", "lassocv.fit(scaler.transform(X_train), y_train.values.ravel())" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "28.6118382967511" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lassocv.alpha_" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "102773.23894326504" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Refit with the cross-validated alpha and score on the test set.\n", "# The test set is standardized with the training-set statistics held by scaler.\n", "lasso.set_params(alpha=lassocv.alpha_)\n", "lasso.fit(scaler.transform(X_train), y_train)\n", "mean_squared_error(y_test, lasso.predict(scaler.transform(X_test)))" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AtBat 0.000000\n", "Hits 0.000000\n", "HmRun 3.338801\n", "Runs 0.000000\n", "RBI 30.295110\n", "Walks 104.483629\n", "Years -0.000000\n", "CAtBat 0.000000\n", "CHits 0.000000\n", "CHmRun 0.000000\n", 
"CRuns 133.994637\n", "CRBI 0.000000\n", "CWalks 0.000000\n", "PutOuts 3.290144\n", "Assists -52.513370\n", "Errors 77.650527\n", "League_N 0.000000\n", "Division_W 0.000000\n", "NewLeague_N 0.000000\n", "dtype: float64" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Some of the coefficients are now reduced to exactly zero.\n", "# lasso was fit on X_train, so label its coefficients with X_train's columns\n", "# (X names and orders the dummy columns differently).\n", "pd.Series(lasso.coef_, index=X_train.columns)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Principal Components Regression" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(19, 19)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
012345
00.198290-0.3837840.0886260.0319670.028117-0.070646
10.195861-0.3772710.0740320.017982-0.004652-0.082240
20.204369-0.237136-0.216186-0.2358310.077660-0.149646
30.198337-0.377721-0.017166-0.049942-0.038536-0.136660
40.235174-0.314531-0.073085-0.1389850.024299-0.111675
\n", "
" ], "text/plain": [ " 0 1 2 3 4 5\n", "0 0.198290 -0.383784 0.088626 0.031967 0.028117 -0.070646\n", "1 0.195861 -0.377271 0.074032 0.017982 -0.004652 -0.082240\n", "2 0.204369 -0.237136 -0.216186 -0.235831 0.077660 -0.149646\n", "3 0.198337 -0.377721 -0.017166 -0.049942 -0.038536 -0.136660\n", "4 0.235174 -0.314531 -0.073085 -0.138985 0.024299 -0.111675" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pca = PCA()\n", "X_reduced = pca.fit_transform(scale(X))\n", "\n", "print(pca.components_.shape)\n", "pd.DataFrame(pca.components_.T).loc[:4,:5]" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(263, 19)\n", "[7.30749065e+00 4.16564336e+00 2.03815790e+00 1.56251989e+00\n", " 1.00246702e+00 8.28606396e-01 6.91971644e-01 5.14987685e-01\n", " 2.51690120e-01 1.85522541e-01 1.37768673e-01 1.27966318e-01\n", " 9.59512911e-02 6.12697946e-02 5.21743719e-02 2.81122622e-02\n", " 1.41463854e-02 4.88096009e-03 1.19182631e-03]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
012345
0-0.0096491.8705221.265145-0.9354811.1096361.211972
10.411434-2.429422-0.909193-0.2642121.2320311.826617
23.4668220.8259470.555469-1.616726-0.857488-1.028712
3-2.558317-0.2309840.519642-2.176251-0.8203011.491696
41.027702-1.5735371.3313823.4940040.9834270.513675
\n", "
" ], "text/plain": [ " 0 1 2 3 4 5\n", "0 -0.009649 1.870522 1.265145 -0.935481 1.109636 1.211972\n", "1 0.411434 -2.429422 -0.909193 -0.264212 1.232031 1.826617\n", "2 3.466822 0.825947 0.555469 -1.616726 -0.857488 -1.028712\n", "3 -2.558317 -0.230984 0.519642 -2.176251 -0.820301 1.491696\n", "4 1.027702 -1.573537 1.331382 3.494004 0.983427 0.513675" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(X_reduced.shape)\n", "print(pca.explained_variance_)\n", "pd.DataFrame(X_reduced).loc[:4,:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above principal components are the same as in R." ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.3831, 0.6015, 0.7084, 0.7903, 0.8429, 0.8863, 0.9226, 0.9496,\n", " 0.9628, 0.9725, 0.9797, 0.9864, 0.9914, 0.9946, 0.9973, 0.9988,\n", " 0.9995, 0.9998, 0.9999])" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Cumulative proportion of variance explained by the principal components.\n", "# Round after summing so that per-component rounding errors do not accumulate.\n", "np.round(np.cumsum(pca.explained_variance_ratio_), decimals=4)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 10-fold CV, with shuffle\n", "n, n_components = X_reduced.shape\n", "y_values = y.to_numpy()  # plain 1-D array; Series.ravel() is deprecated in recent pandas\n", "kf_10 = KFold(n_splits=10, shuffle=True, random_state=1)\n", "\n", "regr = LinearRegression()\n", "mse = []\n", "\n", "# Calculate MSE with only the intercept (no principal components in regression)\n", "score = -1 * cross_val_score(regr, np.ones((n,1)), y_values, cv=kf_10, scoring='neg_mean_squared_error').mean()\n", "mse.append(score)\n", "\n", "# Calculate MSE using CV for all principal components, adding one component at a time.\n", "for i in range(1, n_components + 1):\n", "    score = -1*cross_val_score(regr, X_reduced[:,:i], y_values, cv=kf_10, scoring='neg_mean_squared_error').mean()\n", "    mse.append(score)\n", "\n", "plt.plot(mse, '-v')\n", "plt.xlabel('Number of principal components in regression')\n", "plt.ylabel('MSE')\n", "plt.title('Salary')\n", "plt.xlim(xmin=-1);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above plot indicates that the lowest cross-validation MSE is reached when doing regression on 18 components." 
] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 106.36859204, -21.60350456, 24.2942534 , -36.9858579 ,\n", " -58.41402748, 62.20632652, 24.63862038, 15.82817701,\n", " 29.57680773, 99.64801199, -30.11209105, 20.99269291,\n", " 72.40210574, -276.68551696, -74.17098665, 422.72580227,\n", " -347.05662353, -561.59691587, -83.25441536])" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Regress Salary on all principal-component scores and show the fitted coefficients.\n", "regr_test = LinearRegression().fit(X_reduced, y)\n", "regr_test.coef_" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 1 }