Linear Regression Using Scikit-Learn Library

Linear regression is usually the first machine learning algorithm people study, and for good reason: it is simple and interpretable. In this post, I explain what linear regression is and demonstrate how to implement a model in Python using the scikit-learn library.
Machine Learning • Regression • Python

Linear Regression in Python with scikit-learn (Practical, Clean, and Reproducible)

Linear regression is often the first model people learn—and for good reason: it’s fast, interpretable, and a strong baseline when you want a direct read on how an input relates to an output. This walkthrough explains what linear regression is, how to fit it with scikit-learn, how to interpret the coefficient/intercept, and how to evaluate the model without fooling yourself.

What is linear regression?

Linear regression models the relationship between an input X and an output y using a straight line:

y = mX + b
  • m is the coefficient (slope): how much y changes for a 1-unit change in X.
  • b is the intercept: the predicted y when X = 0.
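To make the formula concrete, here is a tiny worked example. The slope and intercept values are made up purely for illustration:

```python
# Hypothetical fitted line: slope m = 0.5, intercept b = 7.0 (illustrative numbers)
m, b = 0.5, 7.0

def predict(x):
    # y = m*x + b
    return m * x + b

print(predict(10))  # 0.5 * 10 + 7.0 = 12.0
print(predict(0))   # when X = 0, the prediction is just the intercept: 7.0
```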

Strong opinion: treat linear regression as your “truth serum” baseline. If a complex model cannot beat it on a clean test set, the pipeline—not the algorithm—is the problem.
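One way to operationalize the "truth serum" idea is to compare any model against a mean-predicting dummy baseline, which scikit-learn provides via `DummyRegressor`. A minimal sketch on synthetic data (the data-generating numbers here are illustrative assumptions, not from the original post):

```python
import numpy as np
from sklearn.dummy import DummyRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Synthetic data with a known linear signal plus noise
rng = np.random.default_rng(0)
X = rng.uniform(0, 100, size=(200, 1))
y = 0.05 * X.ravel() + 7.0 + rng.normal(0, 1.0, size=200)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

dummy = DummyRegressor(strategy="mean").fit(X_tr, y_tr)   # always predicts the training mean
linear = LinearRegression().fit(X_tr, y_tr)

print("Dummy R^2 :", dummy.score(X_te, y_te))   # near zero (or slightly negative)
print("Linear R^2:", linear.score(X_te, y_te))  # should be clearly above the dummy
```

If a fancier model cannot beat the dummy, let alone the linear fit, suspect the pipeline before the algorithm.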

When linear regression is a good idea

  • Interpretability matters (you need to explain “why” a prediction changed).
  • You want a baseline before moving to nonlinear models.
  • Data is limited and you want to avoid overfitting.
Reality check: linear regression assumes a linear relationship and is sensitive to outliers and omitted variable bias. It’s not “wrong” when assumptions break—it’s just a baseline that tells you the relationship isn’t well-explained by a line.
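A quick way to see the outlier sensitivity in action: fit the same line with and without a single extreme point. The synthetic data below is illustrative (true slope of 2, one injected outlier):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(1)
X = np.arange(0, 20, dtype=float).reshape(-1, 1)
y = 2.0 * X.ravel() + rng.normal(0, 0.5, size=20)  # true slope = 2

clean = LinearRegression().fit(X, y)

# Add one extreme outlier at the right edge of the X range
X_out = np.vstack([X, [[19.0]]])
y_out = np.append(y, 200.0)
dirty = LinearRegression().fit(X_out, y_out)

print("Slope without outlier:", clean.coef_[0])  # close to 2
print("Slope with outlier   :", dirty.coef_[0])  # pulled noticeably upward by one point
```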

Step 1 — Import libraries

import pandas as pd
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
      

Step 2 — Read the dataset

The original example uses a simple dataset with one input feature (e.g., TV advertising spend) and one target (sales). Below is the same structure with slightly improved hygiene (a preview plus basic missing-value checks).

df = pd.read_csv("data.csv")

print(df.head())
print(df.isna().sum())
      

Step 3 — Fit a linear regression model

The core scikit-learn workflow is: separate X and y → split train/test → fit → predict → evaluate. The original post fits directly on all data and prints the coefficient/intercept. Below is a more production-realistic version that avoids training on the full dataset.

Interpretation of outputs

  • model.coef_ = slope (m)
  • model.intercept_ = intercept (b)

Common pitfall: in some posts, “slope” and “intercept” labels are swapped. The coefficient is the slope; the intercept is the intercept.

# Example: predict Sales (y) from TV spend (X)
X = df[["TV"]]          # 2D array-like
y = df["sales"]         # 1D vector-like

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

model = LinearRegression()
model.fit(X_train, y_train)

slope = float(model.coef_[0])
intercept = float(model.intercept_)

print("Slope (coef):", slope)
print("Intercept:", intercept)
      

Step 4 — Predict and visualize

A scatterplot plus the fitted regression line is the fastest way to sanity-check whether the relationship is roughly linear. The original post uses this plotting approach.

y_pred = model.predict(X_test)

# Plot training data (optional) + regression line over full X range
plt.scatter(X, y, alpha=0.8)
x_line = pd.DataFrame({"TV": sorted(X["TV"].values)})
y_line = model.predict(x_line)

plt.plot(x_line["TV"], y_line, linewidth=2)
plt.xlabel("TV")
plt.ylabel("Sales")
plt.title("Linear Regression: Sales vs TV")
plt.show()
      

Step 5 — Evaluate (MAE, RMSE, R²)

If you only look at the line on the training data, you can convince yourself everything is “great.” Evaluation on a held-out test split keeps you honest.

mae = mean_absolute_error(y_test, y_pred)
rmse = mean_squared_error(y_test, y_pred) ** 0.5  # RMSE; avoids `squared=False`, which newer scikit-learn versions removed
r2 = r2_score(y_test, y_pred)

print("MAE :", mae)
print("RMSE:", rmse)
print("R^2 :", r2)
      

Practical improvements (worth doing every time)

  • Check residuals: pattern-free residuals are what you want; patterns mean your line is missing structure.
  • Handle outliers: a few extreme points can swing the slope.
  • Consider multiple regression: add more features if the problem is not truly 1D.
  • Use regularization (Ridge/Lasso) when features are many or correlated.

# Quick residual plot (optional)
residuals = y_test - y_pred

plt.scatter(y_pred, residuals, alpha=0.8)
plt.axhline(0, linewidth=2)
plt.xlabel("Predicted")
plt.ylabel("Residuals (Actual - Predicted)")
plt.title("Residual Plot")
plt.show()
      
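For the regularization bullet above, the swap is nearly drop-in: `Ridge` exposes the same fit/predict API as `LinearRegression`. A sketch on synthetic, nearly collinear features (the data and the `alpha=1.0` value are illustrative assumptions; tune alpha with cross-validation in practice):

```python
import numpy as np
from sklearn.linear_model import Ridge, LinearRegression
from sklearn.model_selection import train_test_split

# Two nearly collinear features: x2 is x1 plus tiny noise
rng = np.random.default_rng(2)
n = 60
x1 = rng.normal(size=n)
x2 = x1 + rng.normal(scale=0.01, size=n)
X = np.column_stack([x1, x2])
y = 3.0 * x1 + rng.normal(scale=0.5, size=n)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=42)

ols = LinearRegression().fit(X_tr, y_tr)
ridge = Ridge(alpha=1.0).fit(X_tr, y_tr)  # alpha is a hyperparameter; 1.0 is an untuned guess

print("OLS coefs  :", ols.coef_)    # can blow up with opposite signs under collinearity
print("Ridge coefs:", ridge.coef_)  # shrunk toward stable, sensible values
```

Note that the two ridge coefficients sum to roughly 3 (the true combined effect), while ordinary least squares may split that effect between the correlated columns in wild, offsetting ways.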

Wrap-up

Linear regression is simple, but it’s not a toy. It’s a fast, interpretable baseline that often beats more complex models when data is limited or the signal is clean. Fit the model, interpret the coefficient and intercept, plot the line, and evaluate on a test set. If the residuals show structure, you’ve learned something important: your relationship isn’t “just a line.”

Reference baseline for this rewrite: “Linear Regression Using Scikit-Learn Library” (Jose Dominguez, Digital Studio Stream).