Implementing Machine Learning Pipelines with Apache Spark
Machine learning pipelines help turn data into predictions. Apache Spark makes it easy to build these pipelines for big data.

Image by Editor (Kanwal Mehreen) | Canva
Apache Spark is an open-source engine for processing big data. It spreads work across a cluster of machines, so it can handle datasets that don't fit in a single computer's memory. A machine learning pipeline is a series of steps to prepare data and train models. These steps include collecting data, cleaning it, selecting important features, training the model, and checking how well it works.
Spark makes it easy to build these pipelines. With Spark, companies can quickly analyze large amounts of data and create machine learning models. This helps them make better decisions based on the information they have. In this article, we will explain how to set up and use machine learning pipelines in Spark.
Components of a Machine Learning Pipeline in Spark
Spark’s MLlib library has many built-in tools. These tools can be linked together to build a complete machine learning process.
Transformers
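Transformers change data in some way. They take a DataFrame and return a new DataFrame, usually with one or more columns added. They are used for tasks like assembling feature vectors or applying an encoding or scaling that has already been learned. VectorAssembler is a transformer. StringIndexer (for encoding) and StandardScaler (for scaling) are technically estimators, because they first have to learn from the data (the set of categories, or the mean and standard deviation); fitting them produces a model that is a transformer. Transformers are reusable, and because DataFrames are immutable they never modify the original data.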
Estimators
Estimators learn from data to create models. They include algorithms like LogisticRegression and RandomForestClassifier. Estimators use a fit method to train on data, and they output a Model object that can make predictions.
Pipeline
A Pipeline is a tool to connect transformers and estimators into a single workflow. By organizing them in sequence, data flows smoothly from one step to the next. Pipelines make it easy to retrain models, repeat processes, and adjust parameters.
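To make the fit/transform contract concrete, here is a toy sketch in plain Python. It is not Spark's implementation: the class names (Doubler, MeanCenterer, ToyPipeline) are hypothetical, and rows are plain dictionaries instead of a DataFrame. It only illustrates how a pipeline fits estimators in order and passes each stage's output to the next.

```python
class Doubler:
    """Toy transformer: needs no fitting, it only adds a column."""

    def __init__(self, input_col, output_col):
        self.input_col = input_col
        self.output_col = output_col

    def transform(self, rows):
        # Return new rows instead of modifying the input in place
        return [{**row, self.output_col: row[self.input_col] * 2} for row in rows]


class MeanCenterer:
    """Toy estimator: fit() learns the column mean and returns a model."""

    def __init__(self, input_col, output_col):
        self.input_col = input_col
        self.output_col = output_col

    def fit(self, rows):
        values = [row[self.input_col] for row in rows]
        mean = sum(values) / len(values)
        return MeanCentererModel(self.input_col, self.output_col, mean)


class MeanCentererModel:
    """Fitted model: a transformer that applies the learned mean."""

    def __init__(self, input_col, output_col, mean):
        self.input_col = input_col
        self.output_col = output_col
        self.mean = mean

    def transform(self, rows):
        return [{**row, self.output_col: row[self.input_col] - self.mean} for row in rows]


class ToyPipeline:
    """Chains stages; fit() returns a model made of fitted stages."""

    def __init__(self, stages):
        self.stages = stages

    def fit(self, rows):
        fitted_stages = []
        current = rows
        for stage in self.stages:
            # Estimators are fitted first; transformers are used as they are
            model = stage.fit(current) if hasattr(stage, "fit") else stage
            fitted_stages.append(model)
            # Later stages see the output of earlier ones
            current = model.transform(current)
        return ToyPipelineModel(fitted_stages)


class ToyPipelineModel:
    """Applies every fitted stage in order."""

    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows):
        for stage in self.stages:
            rows = stage.transform(rows)
        return rows


train_rows = [{"x": 1.0}, {"x": 3.0}]
pipeline = ToyPipeline([Doubler("x", "x2"), MeanCenterer("x2", "x2_centered")])
model = pipeline.fit(train_rows)
# The mean learned from the training rows is reused on unseen rows
print(model.transform([{"x": 5.0}]))
```

The key design point is that everything learned from the training rows (here, the mean of 4.0) is stored in the fitted model, so the same values are applied to any later data without refitting.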
Let's go through a basic example of building a classification pipeline to predict customer churn. In this pipeline, we’ll:
- Load the Data: Import the dataset into Spark for processing.
- Preprocess the Data: Clean and prepare the data for modeling.
- Set Up the Model: Prepare the logistic regression model.
- Train the Model: Fit a machine learning model to the data.
- Evaluate the Model: Check how well the model performs.
Initialize Spark Session and Load Dataset
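First, we use SparkSession.builder to set up the session. Then, we load the customer churn dataset from a CSV file, letting Spark read the header row and infer the column types. The dataset describes bank customers, and the churn column records whether each customer closed their account.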
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder.appName("MLPipeline").getOrCreate()
# Load dataset
data = spark.read.csv("/content/Customer Churn.csv", header=True, inferSchema=True)
# Show the first few rows of the dataset
data.show(5)
Data Preprocessing
First, we check the data for missing values and drop any rows that contain them, so that the data is complete. Next, we convert categorical data into numerical format so that the model can use it. We do this with StringIndexer, which maps each category to an index, and OneHotEncoder, which turns each index into a binary vector. Finally, we combine all the features into a single vector with VectorAssembler and scale it with StandardScaler.
from pyspark.sql import functions as F
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
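# Count missing values per column (the NaN check only applies to float/double columns)
float_columns = {name for name, dtype in data.dtypes if dtype in ("float", "double")}
def is_missing(c):
    null_check = F.col(c).isNull()
    return (null_check | F.isnan(c)) if c in float_columns else null_check
missing_values = data.select([F.count(F.when(is_missing(c), c)).alias(c) for c in data.columns])
missing_values.show()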
# Drop rows with any missing values
data = data.na.drop()
# Identify categorical columns
categorical_columns = ['country', 'gender', 'credit_card', 'active_member']
# Create a list to hold the stages of the pipeline
stages = []
# Apply StringIndexer to convert categorical columns to numerical indices
for column in categorical_columns:
    indexer = StringIndexer(inputCol=column, outputCol=column + "_index")
    stages.append(indexer)
    # Apply OneHotEncoder for categorical features
    encoder = OneHotEncoder(inputCols=[column + "_index"], outputCols=[column + "_ohe"])
    stages.append(encoder)
label_column = 'churn' # The label column
feature_columns = [column + "_ohe" for column in categorical_columns]
# Add numerical columns to the features list
numerical_columns = ['credit_score', 'age', 'tenure', 'balance', 'products_number', 'estimated_salary']
feature_columns += numerical_columns
# Create VectorAssembler to combine all feature columns
vector_assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
stages.append(vector_assembler)
# Scale the features using StandardScaler
scaler = StandardScaler(inputCol="features", outputCol="scaled_features", withMean=True, withStd=True)
stages.append(scaler)
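To see what these preprocessing stages compute, here is a small plain-Python sketch. It is an illustration under stated assumptions, not Spark's code: the helper names (string_index, one_hot, standardize) and the sample values are hypothetical. It mirrors three defaults of the Spark stages: StringIndexer gives index 0 to the most frequent category, OneHotEncoder drops the last category (dropLast=True), and StandardScaler uses the sample standard deviation.

```python
from collections import Counter
from statistics import mean, stdev


def string_index(values):
    """Map each category to an index; the most frequent category gets 0.

    Ties are broken alphabetically.
    """
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: index for index, label in enumerate(ordered)}


def one_hot(index, num_categories, drop_last=True):
    """Turn an index into a binary vector.

    With drop_last=True the last category is encoded as all zeros,
    so the vector has num_categories - 1 entries.
    """
    size = num_categories - 1 if drop_last else num_categories
    vector = [0.0] * size
    if index < size:
        vector[index] = 1.0
    return vector


def standardize(values):
    """Center to zero mean and scale to unit sample standard deviation."""
    mu = mean(values)
    sigma = stdev(values)  # sample standard deviation (n - 1 denominator)
    if sigma == 0:
        return [0.0 for _ in values]
    return [(value - mu) / sigma for value in values]


# Hypothetical sample values, for illustration only
countries = ["France", "Spain", "France", "Germany", "France", "Spain"]
mapping = string_index(countries)
print(mapping)
print([one_hot(mapping[c], len(mapping)) for c in countries])
print(standardize([30.0, 40.0, 50.0]))
```

In this sample, France is the most frequent value and gets index 0, while Germany, the least frequent, is encoded as the all-zero vector. Dropping one category avoids a redundant column, because its value is implied by the others.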
Logistic Regression Model Setup
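We import LogisticRegression from pyspark.ml.classification and Pipeline from pyspark.ml. Next, we create a logistic regression estimator with LogisticRegression(), pointing it at the scaled feature vector and the label column. We add it as the final stage and wrap all the stages in a Pipeline.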
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
# Logistic Regression Model
lr = LogisticRegression(featuresCol='scaled_features', labelCol=label_column)
stages.append(lr)
# Create the pipeline (it is fitted in the training step)
pipeline = Pipeline(stages=stages)
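As a rough picture of what a fitted binary logistic regression does with each feature vector, here is a plain-Python sketch. The function name predict_churn and the weights are hypothetical; in Spark the coefficients and intercept are learned during fit. The model computes a weighted sum of the features, passes it through the sigmoid function to get a probability, and predicts 1 when the probability exceeds a threshold (0.5 by default).

```python
import math


def predict_churn(features, weights, intercept, threshold=0.5):
    """Return (probability, prediction) for one scaled feature vector."""
    # Linear combination of the features
    margin = sum(w * x for w, x in zip(weights, features)) + intercept
    # Sigmoid squashes the margin into a probability between 0 and 1
    probability = 1.0 / (1.0 + math.exp(-margin))
    # Predict the positive class only when the probability exceeds the threshold
    prediction = 1.0 if probability > threshold else 0.0
    return probability, prediction


# Hypothetical coefficients, for illustration only
example_weights = [1.0, -2.0]
example_intercept = 0.0
print(predict_churn([2.0, 0.0], example_weights, example_intercept))
print(predict_churn([0.0, 1.0], example_weights, example_intercept))
```

Because the features are standardized before they reach the model, the size of each weight gives a rough sense of how strongly that feature pushes the prediction toward or away from churn.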
Model Training and Predictions
We split the dataset into training and testing sets. Then, we fit the pipeline on the training data, which returns a fitted PipelineModel, and use that model to make predictions on the test data.
# Split data into training and testing sets
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)
# Fit the model
pipeline_model = pipeline.fit(train_data)
# Make Predictions
predictions = pipeline_model.transform(test_data)
# Show the predictions
predictions.select("prediction", label_column, "scaled_features").show(10)
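One detail worth knowing about the split: randomSplit assigns rows at random according to the weights, so the resulting sizes are only approximately 80% and 20%, and the seed makes the assignment repeatable. The plain-Python sketch below illustrates that idea. The function random_split is hypothetical, and Spark's own sampling is more involved, so treat this as a conceptual model only.

```python
import random
from itertools import accumulate


def random_split(rows, weights, seed):
    """Assign each row to one split at random, in proportion to the weights."""
    total = sum(weights)
    # Cumulative upper bounds, e.g. [0.8, 1.0] for weights [0.8, 0.2]
    bounds = list(accumulate(w / total for w in weights))
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        draw = rng.random()
        for i, upper in enumerate(bounds):
            # The last split also catches any floating-point rounding
            if draw < upper or i == len(bounds) - 1:
                splits[i].append(row)
                break
    return splits


train_rows, test_rows = random_split(list(range(1000)), [0.8, 0.2], seed=42)
# Sizes are close to 800 and 200, but not guaranteed to be exact
print(len(train_rows), len(test_rows))
```

Fixing the seed matters for reproducibility: without it, each run would train and evaluate on different rows, and the reported metrics would shift from run to run.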
Model Evaluation
We import MulticlassClassificationEvaluator from pyspark.ml.evaluation to evaluate our model's performance. We calculate accuracy, weighted precision, weighted recall, and F1 score from the model's predictions. The weighted metrics average the per-class values, weighting each class by how often it appears in the true labels. Finally, we stop the Spark session to free up resources.
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Accuracy
evaluator_accuracy = MulticlassClassificationEvaluator(labelCol=label_column, predictionCol="prediction", metricName="accuracy")
accuracy = evaluator_accuracy.evaluate(predictions)
print(f"Accuracy: {accuracy}")
# Weighted precision (per-class precision, weighted by class frequency)
evaluator_precision = MulticlassClassificationEvaluator(labelCol=label_column, predictionCol="prediction", metricName="weightedPrecision")
precision = evaluator_precision.evaluate(predictions)
print(f"Precision: {precision}")
# Weighted recall (per-class recall, weighted by class frequency)
evaluator_recall = MulticlassClassificationEvaluator(labelCol=label_column, predictionCol="prediction", metricName="weightedRecall")
recall = evaluator_recall.evaluate(predictions)
print(f"Recall: {recall}")
# F1 score (per-class F1, weighted by class frequency)
evaluator_f1 = MulticlassClassificationEvaluator(labelCol=label_column, predictionCol="prediction", metricName="f1")
f1_score = evaluator_f1.evaluate(predictions)
print(f"F1 Score: {f1_score}")
# Stop Spark session
spark.stop()
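To clarify what the weighted metrics mean, here is a plain-Python sketch that computes them from a list of labels and predictions. The function weighted_metrics and the sample data are hypothetical; the sketch follows the definitions (per-class precision, recall, and F1, averaged with weights equal to each class's share of the true labels) and is not Spark's implementation.

```python
def weighted_metrics(labels, predictions):
    """Compute accuracy and class-frequency-weighted precision, recall, and F1."""
    n = len(labels)
    pairs = list(zip(labels, predictions))
    accuracy = sum(1 for y, p in pairs if y == p) / n

    weighted_precision = weighted_recall = weighted_f1 = 0.0
    for c in sorted(set(labels)):
        true_positives = sum(1 for y, p in pairs if y == c and p == c)
        predicted_as_c = sum(1 for p in predictions if p == c)
        actual_c = sum(1 for y in labels if y == c)

        precision = true_positives / predicted_as_c if predicted_as_c else 0.0
        recall = true_positives / actual_c
        denominator = precision + recall
        f1 = 2 * precision * recall / denominator if denominator else 0.0

        # Each class contributes in proportion to its share of the true labels
        weight = actual_c / n
        weighted_precision += weight * precision
        weighted_recall += weight * recall
        weighted_f1 += weight * f1

    return {
        "accuracy": accuracy,
        "weighted_precision": weighted_precision,
        "weighted_recall": weighted_recall,
        "f1": weighted_f1,
    }


# Hypothetical example: 8 customers stayed (0), 2 churned (1),
# and the model predicts "stayed" for everyone
sample_labels = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
sample_predictions = [0] * 10
print(weighted_metrics(sample_labels, sample_predictions))
```

Two things stand out in this example. Weighted recall always equals accuracy, so the two numbers carry the same information. And a model that never predicts churn still scores 0.8 accuracy here, because the majority class dominates the weighted averages. For imbalanced data like churn, it helps to also look at metrics for the churn class alone; Spark 3.0 and later offer per-label options such as precisionByLabel and recallByLabel for this.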
Conclusion
In this article, we learned about machine learning pipelines in Apache Spark. Pipelines help organize each step of the ML process. We started by loading and cleaning the customer churn dataset. Then, we transformed the data and created a logistic regression model. After training the model, we made predictions on the held-out test data. Finally, we evaluated the model's performance using accuracy, weighted precision, weighted recall, and F1 score.
Jayita Gulati is a machine learning enthusiast and technical writer driven by her passion for building machine learning models. She holds a Master's degree in Computer Science from the University of Liverpool.