MLflow on Databricks

TL;DR

MLflow is an open-source platform for managing the entire ML lifecycle. On Databricks it's pre-installed and integrated with Unity Catalog. It tracks experiments (what you tried), logs metrics (how well it worked), stores models (the artifacts), and deploys them to production — all from your notebook.

Explain Like I'm 12

Imagine you're entering a science fair. You try a bunch of different experiments — maybe growing plants with different fertilizers, different amounts of water, under different lights. For each experiment, you write down exactly what you did (the setup), what happened (the results), and whether the plant grew tall or not (the score).

After 20 experiments, you look at your lab notebook, compare all the results, and pick the best one to present to the judges. You put that winning plant on a nice display stand with a label so everyone knows which one it is.

MLflow is that lab notebook. It automatically records every experiment you run (what settings you used, how accurate the model was), lets you compare them side by side, and then helps you pick the best model and put it on a "display stand" (production) where real users can use it. The best part? On Databricks, it does most of this automatically — you don't even have to remember to write things down.

The 4 Components of MLflow

MLflow is not one tool — it's four tools bundled together. Think of them as four stages of the ML assembly line, each solving a different problem in the journey from "I have an idea" to "this model is serving predictions in production."

| Component | What It Does | Analogy |
| --- | --- | --- |
| MLflow Tracking | Logs parameters, metrics, and artifacts for every experiment run | Your lab notebook — records what you tried and what happened |
| MLflow Models | A standard packaging format that wraps any ML framework (sklearn, PyTorch, XGBoost, custom) | A universal shipping box — no matter what's inside, everyone knows how to open it |
| Model Registry | Version control for models — register, version, approve, and promote models through stages | A trophy case with labels: "v1 — testing", "v2 — champion", "v3 — retired" |
| MLflow Projects | A convention for packaging code as reproducible runs (conda/Docker environments) | A recipe card with an ingredients list — anyone can reproduce the exact same dish |

On Databricks, focus on Tracking and the Model Registry. MLflow Projects are less relevant because Databricks notebooks already handle environment management and reproducibility. You'll spend 90% of your time with Tracking (logging experiments) and the Model Registry (promoting models to production).

Experiment Tracking

This is the core of MLflow and the part you'll use every single day. The idea is simple: every time you train a model, you create a run. Each run lives inside an experiment (a logical grouping, usually one per project or notebook). Inside that run, you log three things:

  • Parameters — the inputs to your model (learning rate, number of trees, regularization strength)
  • Metrics — the outputs you care about (accuracy, RMSE, F1 score, AUC)
  • Artifacts — any files produced (the model itself, plots, confusion matrices, feature importance charts)

On Databricks, every notebook is automatically associated with an experiment. You don't need to set anything up. Just start logging.

Manual Logging

Here's the most explicit way to track an experiment. You control exactly what gets logged:

import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

# Load your data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Start a run — everything inside this block gets tracked
with mlflow.start_run(run_name="rf-baseline"):
    # Define hyperparameters
    n_estimators = 100
    max_depth = 10

    # Log parameters (the knobs you turned)
    mlflow.log_param("n_estimators", n_estimators)
    mlflow.log_param("max_depth", max_depth)
    mlflow.log_param("test_size", 0.2)

    # Train the model
    model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
    model.fit(X_train, y_train)

    # Evaluate
    predictions = model.predict(X_test)
    acc = accuracy_score(y_test, predictions)
    f1 = f1_score(y_test, predictions, average="weighted")

    # Log metrics (how well it worked)
    mlflow.log_metric("accuracy", acc)
    mlflow.log_metric("f1_score", f1)

    # Log the model artifact
    mlflow.sklearn.log_model(model, "random-forest-model")

    print(f"Accuracy: {acc:.4f}, F1: {f1:.4f}")

After running this cell in your Databricks notebook, click the "Experiment" icon in the right sidebar. You'll see your run with all the parameters and metrics, ready to compare.

Autologging — The Easy Way

Manual logging is fine for learning, but in practice, you want MLflow to capture everything automatically. That's what autologging does:

import mlflow

# One line — that's it. Put this at the top of your notebook.
mlflow.autolog()

# Now train as normal — MLflow captures everything
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

model = GradientBoostingClassifier(n_estimators=200, max_depth=5, learning_rate=0.1)
model.fit(X_train, y_train)

# MLflow automatically logged:
# - All hyperparameters (n_estimators, max_depth, learning_rate, etc.)
# - Training metrics (accuracy on training set)
# - The model artifact
# - Feature importance plot
# - A model signature (input/output schema)

Always call mlflow.autolog() at the top of your notebook. It captures everything automatically for scikit-learn, XGBoost, LightGBM, PyTorch, TensorFlow, Keras, Spark ML, and Statsmodels. You can still add manual log_param() or log_metric() calls on top for custom values that autolog doesn't capture.

Comparing Runs in the UI

After running several experiments with different hyperparameters, you'll want to compare them. In the Databricks experiment UI:

  1. Click the Experiment icon in the notebook sidebar (or navigate to the Experiments page)
  2. Select the runs you want to compare (checkbox each row)
  3. Click "Compare" — you'll get a side-by-side view of parameters and metrics
  4. Use the parallel coordinates plot to visualize how hyperparameters relate to performance
  5. Sort by any metric column to quickly find the best run

You can also query runs programmatically:

# Search for the best run by accuracy
best_run = mlflow.search_runs(
    order_by=["metrics.accuracy DESC"],
    max_results=1
)

print(f"Best run ID: {best_run.iloc[0]['run_id']}")
print(f"Best accuracy: {best_run.iloc[0]['metrics.accuracy']:.4f}")

Logging Models

When you log a model with MLflow, you're not just saving a pickle file. You're creating a standard package that includes the serialized model, its dependencies, and a description of its expected inputs and outputs. This standard format is what makes everything downstream — registry, serving, batch inference — work seamlessly.

Framework-Specific Logging

MLflow has built-in "flavors" for popular ML frameworks. Each flavor knows how to serialize and deserialize that framework's models:

import mlflow
from sklearn.ensemble import RandomForestClassifier

with mlflow.start_run():
    model = RandomForestClassifier(n_estimators=100)
    model.fit(X_train, y_train)

    # sklearn flavor — MLflow knows how to save/load sklearn models
    mlflow.sklearn.log_model(model, "model")

Model Signatures

A signature describes the model's expected input and output schema. Think of it as a contract: "I expect a DataFrame with these columns and types, and I'll return predictions in this format." This is critical for serving — the endpoint validates incoming requests against the signature.

import mlflow
from mlflow.models import infer_signature
from sklearn.ensemble import GradientBoostingClassifier

with mlflow.start_run():
    model = GradientBoostingClassifier(n_estimators=200, max_depth=5)
    model.fit(X_train, y_train)

    # Infer the signature from training data and predictions
    predictions = model.predict(X_test)
    signature = infer_signature(X_test, predictions)

    # Log with signature and an input example
    mlflow.sklearn.log_model(
        model,
        "gb-classifier",
        signature=signature,
        input_example=X_test.iloc[:3]  # first 3 rows as example
    )

Custom Models with pyfunc

What if your model isn't just a single sklearn estimator? Maybe it's a pipeline with custom preprocessing, or a model that calls an external API. Use mlflow.pyfunc to wrap anything:

import mlflow.pyfunc
import pandas as pd

class CustomChurnModel(mlflow.pyfunc.PythonModel):
    """Custom model that wraps preprocessing + prediction."""

    def load_context(self, context):
        """Called once when the model is loaded."""
        import joblib
        self.preprocessor = joblib.load(context.artifacts["preprocessor"])
        self.model = joblib.load(context.artifacts["model"])

    def predict(self, context, model_input: pd.DataFrame) -> pd.DataFrame:
        """Called for each prediction request."""
        processed = self.preprocessor.transform(model_input)
        predictions = self.model.predict_proba(processed)[:, 1]
        return pd.DataFrame({"churn_probability": predictions})

# Log the custom model
with mlflow.start_run():
    mlflow.pyfunc.log_model(
        artifact_path="churn-model",
        python_model=CustomChurnModel(),
        artifacts={
            "preprocessor": "/path/to/preprocessor.pkl",
            "model": "/path/to/model.pkl"
        },
        pip_requirements=["scikit-learn==1.4.0", "pandas>=2.0"]
    )

MLflow uses a standard "MLmodel" format. When you log a model, MLflow creates a directory with an MLmodel YAML file (metadata), the serialized model, a conda.yaml (dependencies), and a requirements.txt. This standard packaging means any framework — sklearn, PyTorch, TensorFlow, XGBoost, or custom — can be loaded and served the exact same way. You never have to worry about "how do I deserialize this?" again.
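To make that layout concrete, here is a sketch of what the MLmodel metadata file might contain for a logged sklearn model. The exact fields and values vary by MLflow version and flavor, so treat this as illustrative rather than authoritative:

```yaml
artifact_path: model
flavors:
  python_function:
    loader_module: mlflow.sklearn
    model_path: model.pkl
    python_version: 3.11.4
  sklearn:
    pickled_model: model.pkl
    serialization_format: cloudpickle
    sklearn_version: 1.4.0
signature:
  inputs: '[{"name": "tenure_months", "type": "long"}]'
  outputs: '[{"type": "long"}]'
```

Any MLflow-aware tool reads the python_function flavor first, which is why every model, whatever the framework, can be loaded the same way.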

Model Registry

Experiment tracking answers "which model is best?" The Model Registry answers "which model is in production?" It's version control for models — like Git, but for trained artifacts instead of code.

The Model Lifecycle

In the classic workspace registry, every registered model moves through stages (Unity Catalog replaces stages with aliases such as "champion" and "challenger"):

  1. None — just registered, not assigned to any stage
  2. Staging — being validated (shadow traffic, A/B testing, integration tests)
  3. Production — serving live traffic, this is the champion model
  4. Archived — retired but kept for audit/rollback purposes

Registering a Model

import mlflow

# Option 1: Register during logging (most common)
with mlflow.start_run():
    mlflow.sklearn.log_model(
        model,
        "model",
        registered_model_name="catalog.schema.churn_predictor"  # Unity Catalog path
    )

# Option 2: Register an existing run's model after the fact
result = mlflow.register_model(
    model_uri="runs:/abc123def456/model",
    name="catalog.schema.churn_predictor"
)
print(f"Registered version: {result.version}")

Managing Model Versions

from mlflow import MlflowClient

client = MlflowClient()

# Unity Catalog models use aliases instead of stages.
# Promote a new version by pointing the "champion" alias at it:
client.set_registered_model_alias(
    name="catalog.schema.churn_predictor",
    alias="champion",
    version=3
)

# Look up which version currently holds the alias
champion = client.get_model_version_by_alias(
    name="catalog.schema.churn_predictor",
    alias="champion"
)
print(f"Current champion: v{champion.version}")

# Load the champion model
import mlflow.pyfunc
model = mlflow.pyfunc.load_model("models:/catalog.schema.churn_predictor@champion")

Classic vs Unity Catalog Model Registry

| Feature | Classic (Workspace) Registry | Unity Catalog Registry |
| --- | --- | --- |
| Namespace | Flat names within workspace (churn_model) | Three-level namespace (catalog.schema.churn_model) |
| Governance | Workspace-level ACLs only | Full Unity Catalog permissions (GRANT/REVOKE) |
| Cross-workspace | Models stuck in one workspace | Models accessible across all workspaces in the account |
| Lineage | Basic run tracking | Full data-to-model lineage (which tables trained this model?) |
| Stage transitions | None / Staging / Production / Archived | Aliases (champion, challenger, etc.) — more flexible |
| Access control | Workspace-level permissions | Fine-grained: per-catalog, per-schema, or per-model |
| Status | Being deprecated | Current recommended approach |

The classic workspace Model Registry (without Unity Catalog) is being deprecated. Databricks is actively migrating customers to Unity Catalog. If you're starting a new project, always use Unity Catalog for model registration (catalog.schema.model_name). If you have existing models in the workspace registry, plan a migration — Databricks provides upgrade tools to move them over.
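Governance in the Unity Catalog registry uses plain SQL grants. As a sketch (the model name and the `serving-sp` principal are hypothetical), giving a service principal permission to load and score a registered model might look like:

```sql
-- Let the serving principal load and run the model
GRANT EXECUTE ON MODEL catalog.schema.churn_predictor TO `serving-sp`;

-- It also needs USE privileges on the enclosing catalog and schema
GRANT USE CATALOG ON CATALOG catalog TO `serving-sp`;
GRANT USE SCHEMA ON SCHEMA catalog.schema TO `serving-sp`;
```

The same GRANT/REVOKE machinery you use for tables applies to models, which is what fine-grained access control means here.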

Model Serving

You've trained a model, tracked the experiment, and registered the winning version. Now you need to get predictions out of it. On Databricks, you have two main options: real-time serving (REST API endpoint) and batch inference (score a whole table at once).

Real-Time Model Serving

Databricks Model Serving creates a serverless REST API endpoint for your registered model. You send it a JSON payload, it returns predictions in milliseconds. No infrastructure to manage — Databricks handles scaling, load balancing, and GPU allocation.

To create a serving endpoint:

  1. Go to Serving in the left sidebar
  2. Click "Create serving endpoint"
  3. Select your registered model and version (or alias like "champion")
  4. Choose compute size (Small / Medium / Large, with optional GPU)
  5. Databricks provisions the endpoint — it's ready in minutes
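The same setup can be done through the serving-endpoints REST API. The JSON body below is a sketch (the endpoint name, model version, and workload size are illustrative; check the API reference for the exact schema in your workspace):

```json
{
  "name": "churn-predictor-endpoint",
  "config": {
    "served_entities": [
      {
        "entity_name": "catalog.schema.churn_predictor",
        "entity_version": "3",
        "workload_size": "Small",
        "scale_to_zero_enabled": true
      }
    ]
  }
}
```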

Once the endpoint is live, query it with any HTTP client:

import requests
import json

# Your Databricks workspace URL and personal access token
workspace_url = "https://your-workspace.cloud.databricks.com"
token = dbutils.secrets.get(scope="ml-serving", key="token")

# The endpoint name matches what you configured in the UI
endpoint_name = "churn-predictor-endpoint"

# Prepare the request payload
payload = {
    "dataframe_records": [
        {
            "tenure_months": 24,
            "monthly_charges": 79.99,
            "total_charges": 1919.76,
            "contract_type": "month-to-month",
            "payment_method": "electronic_check"
        }
    ]
}

# Call the endpoint
response = requests.post(
    f"{workspace_url}/serving-endpoints/{endpoint_name}/invocations",
    headers={
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    },
    json=payload
)

result = response.json()
print(f"Churn probability: {result['predictions'][0]:.4f}")

Batch Inference with spark_udf()

Real-time endpoints are great for one-at-a-time predictions (user clicks a button, app needs a score). But what if you need to score 10 million customers overnight? That's batch inference, and spark_udf() is the tool for it:

import mlflow

# Load the champion model as a Spark UDF
predict_udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri="models:/catalog.schema.churn_predictor@champion",
    result_type="double"
)

# Score an entire table in parallel across the cluster.
# The UDF needs the feature columns as input — pass them as a struct.
from pyspark.sql.functions import struct

features_df = spark.table("catalog.schema.customer_features")
scored_df = features_df.withColumn(
    "churn_probability", predict_udf(struct(*features_df.columns))
)

# Write results back to a Delta table
scored_df.write.mode("overwrite").saveAsTable("catalog.schema.churn_scores")

When to Use Each

| Aspect | Real-Time Serving (REST API) | Batch Inference (spark_udf) |
| --- | --- | --- |
| Latency | Milliseconds per request | Minutes to hours for full table |
| Use case | User-facing apps, chatbots, fraud detection | Nightly scoring, report generation, feature pipelines |
| Volume | One record (or small batch) at a time | Millions to billions of records |
| Cost | Pay for always-on endpoint (or scale to zero) | Pay for cluster time only while job runs |
| Infrastructure | Serverless — Databricks manages everything | Uses your existing Spark cluster |
| Freshness | Real-time, up-to-the-second features | As fresh as your last batch run |

For batch scoring (e.g., score all customers nightly), use spark_udf(). It's dramatically cheaper than a serving endpoint because you only pay for cluster time while the job is running. Reserve real-time endpoints for use cases where latency matters — a user waiting for a recommendation, a transaction being checked for fraud, or an API call that needs sub-second response.
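To see why batch wins on cost for periodic scoring, here is a back-of-envelope comparison. The hourly rates are entirely hypothetical; real prices depend on your cloud, region, and endpoint size:

```python
# Hypothetical hourly rates -- substitute your actual cloud/Databricks pricing
ENDPOINT_RATE = 3.50   # always-on Small serving endpoint, $/hour
CLUSTER_RATE = 10.00   # batch cluster (driver + workers), $/hour

hours_per_month = 24 * 30
batch_job_hours = 0.5        # nightly job runs ~30 minutes
batch_runs_per_month = 30

endpoint_cost = ENDPOINT_RATE * hours_per_month
batch_cost = CLUSTER_RATE * batch_job_hours * batch_runs_per_month

print(f"Always-on endpoint: ${endpoint_cost:,.2f}/month")  # $2,520.00/month
print(f"Nightly batch job:  ${batch_cost:,.2f}/month")     # $150.00/month
```

Even with a cluster that costs several times more per hour, paying only for the 30 minutes the job actually runs is far cheaper than keeping an endpoint warm all month (scale-to-zero narrows, but rarely closes, this gap).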

Feature Store

Here's a problem you'll hit sooner or later: your data scientist computes features one way during training (in a notebook, with a specific query), and the engineer who builds the production pipeline computes them slightly differently (different join logic, different aggregation window, a subtle timezone bug). The model looks great in testing and terrible in production. This is called training-serving skew, and it's one of the most common ML production failures.

The Databricks Feature Store (now called Feature Engineering in Unity Catalog) solves this by making features a first-class, reusable asset. You define a feature once, store it in a governed table, and both training and serving pull from the same source.

Creating and Using Feature Tables

from databricks.feature_engineering import FeatureEngineeringClient

fe = FeatureEngineeringClient()

# Step 1: Compute features from raw data
customer_features_df = spark.sql("""
    SELECT
        customer_id,
        COUNT(*) AS total_orders,
        AVG(order_amount) AS avg_order_value,
        DATEDIFF(current_date(), MAX(order_date)) AS days_since_last_order,
        SUM(CASE WHEN returned = true THEN 1 ELSE 0 END) AS total_returns
    FROM catalog.schema.orders
    GROUP BY customer_id
""")

# Step 2: Create (or update) a feature table in Unity Catalog
fe.create_table(
    name="catalog.schema.customer_features",
    primary_keys=["customer_id"],
    df=customer_features_df,
    description="Customer behavioral features for churn prediction"
)

# Step 3: Train a model using features from the Feature Store
from databricks.feature_engineering import FeatureLookup

training_set = fe.create_training_set(
    df=spark.table("catalog.schema.churn_labels"),  # just customer_id + label
    feature_lookups=[
        FeatureLookup(
            table_name="catalog.schema.customer_features",
            lookup_key="customer_id"
        )
    ],
    label="churned"
)

# Convert to Pandas for sklearn training
training_df = training_set.load_df().toPandas()

Feature Store ensures training and serving use the same features. When you log a model trained with Feature Store lookups, MLflow records which feature tables were used. At serving time, the model endpoint automatically fetches the latest features from those same tables. No manual feature engineering in your serving code, no risk of computing features differently. One source of truth.

MLflow vs Alternatives

MLflow isn't the only ML lifecycle platform. Here's how it stacks up against the major competitors. The right choice depends on your cloud provider, your team's workflow, and whether you value openness over managed convenience.

| Feature | MLflow (Databricks) | Weights & Biases | AWS SageMaker | Google Vertex AI |
| --- | --- | --- | --- | --- |
| Experiment tracking | Built-in, autolog for 8+ frameworks | Excellent UI, real-time dashboards, sweeps | SageMaker Experiments (more manual) | Vertex Experiments (TensorBoard-based) |
| Model registry | Unity Catalog integration, aliases, lineage | W&B Registry (newer, growing) | SageMaker Model Registry | Vertex Model Registry |
| Model serving | Serverless endpoints, GPU, scale-to-zero | No native serving (integrates with others) | SageMaker Endpoints (robust, many options) | Vertex Endpoints (autoscaling) |
| Open source | Yes — Apache 2.0 license | Client is open, server is proprietary | No | No |
| Cloud lock-in | Low — works anywhere (Databricks, local, any cloud) | Low — SaaS, cloud-agnostic | High — AWS only | High — GCP only |
| Pricing | Free (open source) or included with Databricks | Free tier, then per-user SaaS pricing | Pay per compute (training + hosting hours) | Pay per compute (training + prediction units) |
| Best for | Databricks users, teams wanting open standards | Research teams, experiment-heavy workflows | AWS-native shops with deep SageMaker investment | GCP-native shops using TFX/BigQuery ML |

Bottom line: If you're on Databricks, MLflow is the obvious choice — it's pre-installed, deeply integrated, and you're already paying for it. If you need best-in-class experiment visualization and your team values a polished UI, consider W&B alongside MLflow. If you're all-in on a single cloud and want a fully managed solution, SageMaker or Vertex AI eliminate the need to think about infrastructure.

Test Yourself

Q: What are the 4 components of MLflow, and which two are most important on Databricks?

The four components are MLflow Tracking (log experiments), MLflow Models (standard packaging format), Model Registry (version and promote models), and MLflow Projects (reproducible runs). On Databricks, Tracking and the Model Registry are the most important because Databricks notebooks already handle reproducibility (making Projects less necessary).

Q: What does mlflow.autolog() do and which frameworks does it support?

mlflow.autolog() automatically captures parameters, metrics, model artifacts, and signatures whenever you call .fit() on a supported framework. It supports scikit-learn, XGBoost, LightGBM, PyTorch, TensorFlow/Keras, Spark ML, Statsmodels, and more. You don't need to add manual log_param() or log_metric() calls — just put mlflow.autolog() at the top of your notebook.

Q: What's the difference between logging a model and registering a model?

Logging a model (mlflow.sklearn.log_model()) saves the model artifact as part of an experiment run. It's like saving your work. Registering a model (mlflow.register_model()) adds it to the Model Registry with a name and version number, making it available for stage transitions (Staging, Production, Archived) and serving. You log many models during experimentation; you register only the ones you want to promote toward production.

Q: When would you use Model Serving vs batch inference with spark_udf()?

Use Model Serving (REST API endpoint) when you need low-latency, real-time predictions for individual requests — user-facing apps, fraud detection, chatbot responses. Use batch inference with spark_udf() when you need to score large volumes of data (millions of rows) at once — nightly customer scoring, feature pipelines, reports. Batch is much cheaper because you only pay for cluster time while the job runs, whereas serving endpoints incur cost as long as they're running.

Q: Why is Unity Catalog Model Registry preferred over the classic workspace registry?

Unity Catalog Model Registry provides three-level namespacing (catalog.schema.model), cross-workspace access (models aren't stuck in one workspace), fine-grained governance (GRANT/REVOKE at catalog, schema, or model level), full data-to-model lineage (which tables trained this model?), and flexible aliases (champion/challenger instead of rigid stages). The classic workspace registry is being deprecated by Databricks.

Interview Questions

Q: Walk through the MLflow model lifecycle from training to production on Databricks. What happens at each stage?

1. Experiment tracking: A data scientist trains multiple model variants in a Databricks notebook. mlflow.autolog() captures all parameters, metrics, and artifacts for each run. The scientist compares runs in the Experiment UI and identifies the best performer.

2. Model logging: The best model is logged with a signature and input example using mlflow.sklearn.log_model() (or the appropriate flavor).

3. Model registration: The model is registered in Unity Catalog with registered_model_name="catalog.schema.model_name". This creates version 1 in the registry.

4. Validation: The new version is assigned the "challenger" alias. Automated tests run — checking performance against a holdout set, verifying the signature, ensuring latency requirements are met.

5. Promotion: After passing validation, the alias is updated to "champion" via client.set_registered_model_alias(). The previous champion gets archived or becomes a fallback.

6. Serving: A Model Serving endpoint references the "champion" alias. When the alias changes, the endpoint automatically picks up the new model version. For batch use cases, a scheduled job uses spark_udf() with the "champion" alias.

7. Monitoring: Production metrics (latency, prediction distribution, data drift) are tracked. If degradation is detected, the team can roll back by reassigning the "champion" alias to a previous version.

Q: How would you implement A/B testing with MLflow model serving on Databricks?

Databricks Model Serving supports traffic splitting natively. You configure a serving endpoint with multiple model versions and assign a percentage of traffic to each:

1. Register your new model version (v2) alongside the current production version (v1).
2. In the serving endpoint configuration, set traffic routing: e.g., 90% to v1 (champion), 10% to v2 (challenger).
3. The endpoint returns a header indicating which model served each request, allowing you to join predictions with outcomes.
4. After collecting enough data, run a statistical test (chi-squared for classification, t-test for regression) comparing conversion rates or accuracy metrics between v1 and v2.
5. If v2 wins, gradually increase its traffic to 50%, then 100%. If it loses, remove it and archive that version.

Alternatively, for batch A/B testing, score users with both models using spark_udf(), write both predictions to a table, and randomly assign which prediction to show to each user.
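The traffic-routing step can be simulated in plain Python to build intuition. This sketch (model names and the 90/10 split are hypothetical) routes requests proportionally to each version's weight and tallies which version served each one:

```python
import random
from collections import Counter

def route_request(rng: random.Random, weights: dict[str, float]) -> str:
    """Pick a model version for one request, proportional to its traffic weight."""
    names = list(weights)
    return rng.choices(names, weights=[weights[n] for n in names], k=1)[0]

# 90% of traffic to the champion (v1), 10% to the challenger (v2)
traffic = {"v1-champion": 0.9, "v2-challenger": 0.1}
rng = random.Random(42)  # seeded so the simulation is reproducible

served = Counter(route_request(rng, traffic) for _ in range(10_000))
print(served)  # roughly 9,000 requests served by v1, 1,000 by v2
```

A real endpoint does this routing server-side; the point of the simulation is that with only ~1,000 challenger samples, your statistical test needs a correspondingly large effect (or a longer collection window) to reach significance.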

Q: What is training-serving skew and how does the Databricks Feature Store prevent it?

Training-serving skew occurs when the features used during model training differ from the features used during production inference. Common causes include: different SQL joins, different aggregation windows, timezone mismatches, stale cached features, or subtle data type differences between the training notebook and the serving pipeline.
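A toy illustration of one such cause: two "equivalent" implementations of a feature that disagree only because of a timezone mismatch (the function names, column semantics, and fixed offset are made up for this example):

```python
from datetime import datetime, timedelta, timezone

UTC = timezone.utc
PACIFIC = timezone(timedelta(hours=-8))  # naive fixed offset, ignoring DST

def days_since_last_order_training(last_order: datetime, now: datetime) -> int:
    # Training notebook: compares calendar dates in UTC
    return (now.astimezone(UTC).date() - last_order.astimezone(UTC).date()).days

def days_since_last_order_serving(last_order: datetime, now: datetime) -> int:
    # Serving pipeline: same idea, but calendar dates taken in local time
    return (now.astimezone(PACIFIC).date() - last_order.astimezone(PACIFIC).date()).days

# An order placed just after midnight UTC falls on different calendar days
last_order = datetime(2024, 3, 1, 2, 30, tzinfo=UTC)   # Feb 29, 18:30 Pacific
now = datetime(2024, 3, 10, 12, 0, tzinfo=UTC)

print(days_since_last_order_training(last_order, now))  # 9
print(days_since_last_order_serving(last_order, now))   # 10
```

One day of difference in a recency feature is enough to shift predictions for every customer near the boundary, and nothing crashes, so the bug is invisible until model quality drops.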

The Feature Store prevents this by making features a single source of truth. When you train a model using FeatureLookup, MLflow records which feature tables and keys were used. At serving time, the model endpoint automatically fetches features from those same Unity Catalog tables using the same primary key lookup. This eliminates the "two different codepaths" problem entirely — both training and serving execute the same feature retrieval logic against the same governed tables.

Q: How do you handle model rollback if a production model starts degrading?

With Unity Catalog Model Registry, rollback is fast and safe:

1. Detection: Monitor prediction distributions, latency, and business metrics (conversion rate, error rate). Set up alerts for drift using Databricks Lakehouse Monitoring or custom checks.
2. Immediate rollback: Reassign the "champion" alias to the previous known-good version: client.set_registered_model_alias("catalog.schema.model", "champion", previous_version). For serving endpoints, this takes effect within seconds — no redeployment needed.
3. For batch jobs: Since spark_udf() resolves the alias at runtime, the next batch run automatically uses the rolled-back version.
4. Root cause analysis: Compare the degraded model's training data, feature distributions, and parameters against the previous version. Check for data drift in upstream tables, schema changes, or recent feature store updates.
5. Prevention: Implement automated validation gates — before promoting any new model to "champion," run it against a holdout set and verify performance meets a minimum threshold. Use Databricks workflows to automate this check.
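The validation gate in step 5 can be sketched as a plain function. Everything here (the metric name, floor, and regression tolerance) is illustrative; in a real pipeline the metrics would come from evaluating the candidate on a holdout set before reassigning the alias:

```python
def should_promote(candidate: dict, champion: dict,
                   min_accuracy: float = 0.80,
                   max_regression: float = 0.01) -> bool:
    """Gate a challenger before it takes the 'champion' alias.

    Promote only if it clears an absolute quality floor AND does not
    regress meaningfully against the current champion.
    """
    if candidate["accuracy"] < min_accuracy:
        return False  # below the absolute floor
    if candidate["accuracy"] < champion["accuracy"] - max_regression:
        return False  # would be a regression vs. the current champion
    return True

champion = {"accuracy": 0.87}
print(should_promote({"accuracy": 0.88}, champion))  # True: clear improvement
print(should_promote({"accuracy": 0.79}, champion))  # False: below the floor
print(should_promote({"accuracy": 0.84}, champion))  # False: regresses too far
```

In a Databricks workflow this check would run as a job task between registration and client.set_registered_model_alias(), so a failing candidate never reaches the serving endpoint.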