diff --git a/cookbooks/mlflow.mdx b/cookbooks/mlflow.mdx
new file mode 100644
index 0000000..7790e9a
--- /dev/null
+++ b/cookbooks/mlflow.mdx
@@ -0,0 +1,641 @@
+---
+title: "TabPFN with MLflow"
+description: "Learn how to wrap TabPFN as an MLflow PythonModel, register it to Unity Catalog, and deploy it to a Mosaic AI Model Serving endpoint."
+---
+
+## Overview
+
+[TabPFN](https://github.com/priorlabs/tabpfn) is a tabular foundation model that makes accurate predictions in a single forward pass - no dataset-specific training, no hyperparameter search, no separate model artifact per target variable.
+
+The promise: **one registered model** that handles both classification and regression across any tabular dataset.
+
+
+
+ Manage one model instead of many. No per-dataset training runs, no hyperparameter searches, no separate artifacts per outcome.
+
+
+ Use a single model for classification and regression across any tabular dataset - from a notebook or SQL `ai_query()`.
+
+
+
+By the end of this tutorial you will have:
+
+- Wrapped TabPFN as an MLflow `PythonModel`
+- Registered it to Unity Catalog with a `champion` alias
+- Tested it locally on classification and regression tasks
+- Run an end-to-end example on the Lending Club loan dataset
+- Deployed it to a GPU-accelerated Mosaic AI Model Serving endpoint
+
+
+ For the full list of supported TabPFN parameters (estimators, output types, configuration options) see the [TabPFN GitHub repository](https://github.com/priorlabs/tabpfn). For broader context on TabPFN and Databricks, read the [Databricks blog post](https://www.databricks.com/blog/tappfn-ai-accelerates-business-transformation-databricks).
+
+
+---
+
+## Prerequisites
+
+- A Databricks workspace with Unity Catalog enabled
+- Access to a **serverless GPU environment** with at least an **NVIDIA A10G** GPU
+- A TabPFN token from [Prior Labs](https://github.com/priorlabs/tabpfn?tab=readme-ov-file#installation--setup)
+
+
+ TabPFN runs inference via PyTorch, so a GPU significantly speeds up predictions - especially on larger datasets. In the notebook toolbar, select **Serverless** with a GPU-enabled instance (A10G or better). Serverless handles provisioning and scaling automatically.
+
+
+---
+
+## Step 1: Install Dependencies
+
+Run the following in your Databricks notebook cell:
+
+```python
+%pip install mlflow tabpfn
+dbutils.library.restartPython()
+```
+
+---
+
+## Step 2: Configure Your TabPFN Token
+
+TabPFN model weights are gated and require authentication. Store your token as a Databricks secret so it never lives in plain text in your notebook.
+
+Run these commands in the **Databricks CLI**:
+
+```bash
+databricks secrets create-scope tabpfn
+databricks secrets put-secret tabpfn tabpfn_token --string-value "<your-tabpfn-token>"
+```
+
+Then reference the secret in your notebook:
+
+```python
+import os
+
+os.environ["TABPFN_TOKEN"] = dbutils.secrets.get(scope="tabpfn", key="tabpfn_token")
+```
+
+---
+
+## Step 3: Define the Wrapper and Signature
+
+The `TabPFNWrapper` is a custom `mlflow.pyfunc.PythonModel`. It is the heart of this integration and handles three concerns:
+
+- **Dual-format input** - accepts both raw Python objects (from notebooks) and JSON strings (from SQL `ai_query()`)
+- **Task routing** - classification or regression, controlled by `task_config`
+- **Flexible output** - class labels, probabilities, or regression predictions based on `output_type`
+
+All input columns use `DataType.string` so the same endpoint works from Python, REST, and SQL without any changes to the registered model. The `_maybe_parse_json()` helper transparently handles both formats at predict time.
+
+```python
+import json
+import inspect
+
+import mlflow
+import numpy as np
+import pandas as pd
+
+from typing import Literal
+
+from mlflow.models.signature import ModelSignature
+from mlflow.types.schema import ColSpec, DataType, Schema
+from tabpfn import TabPFNClassifier, TabPFNRegressor
+
+
+class TabPFNWrapper(mlflow.pyfunc.PythonModel):
+ """MLflow PythonModel wrapper for TabPFN"""
+
+ _CLASSIFICATION_OUTPUT_TYPES = {"preds", "probas"}
+ """The model output types for classification."""
+
+ _REGRESSION_OUTPUT_TYPES = {"mean", "mode", "median", "quantiles", "main", "full"}
+ """The model output types for regression."""
+
+ @staticmethod
+ def _maybe_parse_json(value: str | dict | list) -> dict | list:
+ """Helper function to parse a JSON string if needed.
+
+ Args:
+ value: The value to parse. Can be a string, dict, or list.
+
+ Returns:
+ The parsed value.
+ """
+ if isinstance(value, str):
+ return json.loads(value)
+ return value
+
+ def _get_output_type(
+ self,
+ task: Literal["classification", "regression"],
+ output_type: str | None
+ ) -> str:
+        """Get the prediction output type.
+
+        Args:
+            task: The prediction task, either "classification" or "regression".
+            output_type: The requested output type, or None to fall back to the
+                task default ("preds" for classification, "mean" for regression).
+
+        Returns:
+            The validated prediction output type.
+        """
+        if output_type is not None:
+            supported_output_types = (
+                self._CLASSIFICATION_OUTPUT_TYPES if task == "classification"
+                else self._REGRESSION_OUTPUT_TYPES
+            )
+            if output_type not in supported_output_types:
+                raise ValueError(
+                    f"Unknown output_type for {task}: {output_type!r}. "
+                    f"Must be one of {sorted(supported_output_types)}"
+                )
+            return output_type
+
+ # Fallback to the defaults
+ return "preds" if task == "classification" else "mean"
+
+ def _init_estimator(
+ self,
+ task: Literal["classification", "regression"],
+ config: dict
+ ) -> TabPFNClassifier | TabPFNRegressor:
+ """Initialize a TabPFN estimator.
+
+ Args:
+ task: The task to initialize the estimator for.
+ config: The configuration for the estimator.
+
+ Returns:
+ The initialized estimator.
+ """
+ Estimator = TabPFNClassifier if task == "classification" else TabPFNRegressor
+
+ # Validate provided config keys against the estimator constructor
+ sig = inspect.signature(Estimator.__init__)
+
+ constructor_params = set(sig.parameters.keys()) - {"self"}
+ supplied_keys = set(config.keys())
+ invalid_keys = supplied_keys - constructor_params
+
+ # Raise an error if any invalid keys are provided
+ if invalid_keys:
+ msg = (
+ f"Config contains invalid parameters for {Estimator.__name__}: {sorted(invalid_keys)}.\n"
+ f"Allowed parameters: {sorted(constructor_params)}"
+ )
+ raise ValueError(msg)
+
+ return Estimator(**config)
+
+ def predict(self, model_input, params=None):
+ """Run predictions.
+
+ TabPFN runs predictions in a single forward pass. The model is
+ fitted on the training data and then used to predict on the test data.
+
+ Args:
+ model_input: The input data to predict on.
+ params: The parameters for the model.
+
+ Returns:
+ The predictions.
+ """
+ # Accept both DataFrame (local pyfunc) and list-of-dicts (serving endpoint)
+ if isinstance(model_input, pd.DataFrame):
+ model_input = model_input.to_dict(orient="records")
+
+        if not isinstance(model_input, list):
+            raise TypeError("model_input must be a list with a single row")
+        if len(model_input) != 1:
+            raise ValueError("model_input must have a single row")
+
+ model_input: dict = model_input[0]
+
+ # Parse the task configuration
+ task_config = self._maybe_parse_json(model_input["task_config"]) or {}
+
+ task = task_config.get("task")
+ if task is None:
+ raise KeyError("Task is required, must be 'classification' or 'regression'")
+
+ tabpfn_config = task_config.get("tabpfn_config") or {}
+ predict_params = task_config.get("predict_params") or {}
+
+ output_type = self._get_output_type(task, predict_params.get("output_type"))
+
+ # Parse the input data
+ X_train = np.array(self._maybe_parse_json(model_input["X_train"]), dtype=np.float64)
+ y_train = np.array(self._maybe_parse_json(model_input["y_train"]), dtype=np.float64)
+ X_test = np.array(self._maybe_parse_json(model_input["X_test"]), dtype=np.float64)
+
+ # Initialize the estimator
+ estimator = self._init_estimator(task, tabpfn_config)
+
+ # Fit the estimator
+ estimator.fit(X_train, y_train)
+
+ # Run predictions
+ if task == "classification":
+ if output_type == "probas":
+ return estimator.predict_proba(X_test).tolist()
+ return estimator.predict(X_test).tolist()
+
+ predictions = estimator.predict(X_test, output_type=output_type)
+ return predictions.tolist()
+```
+
+### Define the Model Signature
+
+The model signature tells MLflow (and Unity Catalog) the expected input and output shapes. All columns are `DataType.string` to support both raw Python and JSON-serialized inputs from SQL `ai_query()`.
+
+```python
+# All inputs are STRING for dual-format support:
+# - JSON strings from SQL ai_query() via to_json()
+# - Raw Python objects (dicts/lists) passed in a pd.DataFrame from notebook calls
+
+input_schema = Schema([
+ ColSpec(DataType.string, "task_config"),
+ ColSpec(DataType.string, "X_train"),
+ ColSpec(DataType.string, "y_train"),
+ ColSpec(DataType.string, "X_test"),
+])
+
+# Output schema is required by Unity Catalog.
+# STRING is the best fit - the wrapper returns variable shapes
+# (flat list for preds, nested list for probas) and the serving
+# layer serializes whatever .tolist() produces into a JSON string.
+output_schema = Schema([ColSpec(DataType.string, name="predictions")])
+signature = ModelSignature(inputs=input_schema, outputs=output_schema)
+```
+
+
+ The `output_schema` is **required** for Unity Catalog registration. Using `DataType.string` is the right choice here because the wrapper can return either a flat list (for `preds`) or a nested list (for `probas`/regression quantiles), and the serving layer serializes the result to JSON regardless.
+
+
+---
+
+## Step 4: Register to Unity Catalog
+
+Log the model with MLflow and register it under a fully qualified Unity Catalog path (`catalog.schema.tabpfn`). The `input_example` uses `json.dumps()` to match the all-string signature - the wrapper deserializes at predict time via `_maybe_parse_json()`.
+
+After registration, we tag the latest version with a `"champion"` alias. The serving endpoint references this alias, so you can promote future versions without touching the endpoint config.
+
+```python
+# Input example uses json.dumps() to match the all-string signature
+# The wrapper's _maybe_parse_json() deserializes these at predict time
+input_example = pd.DataFrame([{
+ "task_config": json.dumps({
+ "task": "classification",
+ "tabpfn_config": {
+ "n_estimators": 8,
+ "softmax_temperature": 0.9,
+ },
+ "predict_params": {
+ "output_type": "preds",
+ },
+ }),
+ "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]),
+ "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]),
+ "X_test": json.dumps([[2.0, 3.0, 0.0]]),
+}])
+
+# Fully qualified Unity Catalog path (portable across workspaces)
+CATALOG = spark.catalog.currentCatalog()
+SCHEMA = spark.catalog.currentDatabase()
+REGISTERED_MODEL_NAME = f"{CATALOG}.{SCHEMA}.tabpfn"
+
+with mlflow.start_run(run_name="tabpfn-registration") as run:
+ model_info = mlflow.pyfunc.log_model(
+ name="tabpfn",
+ python_model=TabPFNWrapper(),
+ signature=signature,
+ input_example=input_example,
+ pip_requirements=["tabpfn", "numpy", "pandas"],
+ registered_model_name=REGISTERED_MODEL_NAME,
+ )
+ print(f"Model URI: {model_info.model_uri}")
+ print(f"Run ID: {run.info.run_id}")
+```
+
+Once logged, tag the latest version as `champion`:
+
+```python
+# Tag the latest version with a "champion" alias for serving
+client = mlflow.MlflowClient()
+versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
+latest = max(versions, key=lambda v: int(v.version))
+client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
+print(f"Set alias 'champion' → version {latest.version} of {REGISTERED_MODEL_NAME}")
+```
+
+
+ Using the `champion` alias decouples your serving endpoint from version numbers. When you retrain or update TabPFN, simply reassign the alias - the endpoint continues to route requests without any configuration change.
+
+
+---
+
+## Step 5: Test Locally
+
+Before deploying, verify the registered model works end-to-end from the notebook. The same model handles both raw Python objects and JSON strings - just load it and call `.predict()`.
+
+
+
+ ```python
+ # Load the registered model from MLflow
+ loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
+
+ # DataFrame with raw Python objects, the notebook-friendly example
+ predictions = loaded_model.predict(pd.DataFrame([{
+ "task_config": {"task": "classification"},
+ "X_train": [
+ [1.0, 2.0, 0.0],
+ [3.0, 4.0, 1.0],
+ [5.0, 6.0, 0.0],
+ [7.0, 8.0, 1.0]
+ ],
+ "y_train": [0.0, 1.0, 0.0, 1.0],
+ "X_test": [
+ [2.0, 3.0, 0.0]
+ ],
+ }]))
+ predictions
+ ```
+
+
+ ```python
+ loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
+
+ predictions = loaded_model.predict(pd.DataFrame([{
+ "task_config": json.dumps({"task": "classification", "predict_params": {"output_type": "probas"}}),
+ "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]),
+ "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]),
+ "X_test": json.dumps([[2.0, 3.0, 0.0]]),
+ }]))
+ predictions
+ ```
+
+
+ ```python
+ loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
+
+ predictions = loaded_model.predict(pd.DataFrame([{
+ "task_config": {"task": "regression", "predict_params": {"output_type": "mean"}},
+ "X_train": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
+ "y_train": [1.5, 3.5, 5.5, 7.5],
+ "X_test": [[2.0, 3.0]],
+ }]))
+ predictions
+ ```
+
+
+
+The wrapper's `_maybe_parse_json()` transparently handles both formats - no code changes required between calling from a notebook and calling from a REST endpoint.
+
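The round-trip is easy to verify in isolation. Below is a minimal standalone sketch of the same parsing logic (re-implemented here for illustration rather than imported from the wrapper):

```python
import json

def maybe_parse_json(value):
    """Parse JSON strings; pass dicts and lists through unchanged."""
    if isinstance(value, str):
        return json.loads(value)
    return value

# A JSON string (as produced by SQL to_json) and a raw Python object
# (as passed from a notebook) normalize to the same structure.
from_sql = maybe_parse_json('{"task": "classification"}')
from_notebook = maybe_parse_json({"task": "classification"})
assert from_sql == from_notebook == {"task": "classification"}

# Nested lists such as feature matrices round-trip the same way.
assert maybe_parse_json("[[1.0, 2.0], [3.0, 4.0]]") == [[1.0, 2.0], [3.0, 4.0]]
```
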
+---
+
+## Step 6: End-to-End Example - Lending Club Loan Data
+
+Let's run a real-world classification task: predicting loan default (Good vs Bad) on the [Lending Club Q2 2018](https://www.kaggle.com/datasets/wordsforthewise/lending-club) dataset. This dataset ships with every Databricks workspace at `/databricks-datasets/`, so you can run this without any additional data download.
+
+### Load and Prepare the Data
+
+```python
+from sklearn.model_selection import train_test_split
+
+# Load Lending Club Q2 2018 (ships with every Databricks workspace)
+df = (
+ spark.read.csv(
+ "/databricks-datasets/lending-club-loan-stats/LoanStats_2018Q2.csv",
+ header=True, inferSchema=True,
+ )
+ .select(
+ "loan_status",
+ "loan_amnt", "funded_amnt", "installment", "annual_inc", "dti",
+ "open_acc", "revol_bal", "total_acc", "delinq_2yrs", "inq_last_6mths",
+ "pub_rec", "mort_acc", "tot_cur_bal", "total_pymnt", "last_pymnt_amnt",
+ )
+ .dropna(subset=["loan_status"])
+ .toPandas()
+)
+
+# Binary target: Good (0) vs Bad (1)
+df["target"] = (df["loan_status"].apply(
+ lambda s: 0 if s in ("Current", "Fully Paid") else 1
+).astype(int))
+df = df.drop(columns=["loan_status"]).dropna()
+
+# Sample and split
+df_sample = df.sample(n=6_000, random_state=42)
+X_train, X_test, y_train, y_test = train_test_split(
+ df_sample.drop(columns=["target"]), df_sample["target"],
+ test_size=0.2, random_state=42, stratify=df_sample["target"],
+)
+
+print(f"Train: {len(X_train):,} × {X_train.shape[1]} (bad-loan rate: {y_train.mean():.2%})")
+print(f"Test: {len(X_test):,} × {X_test.shape[1]}")
+```
+
+### Run Predictions
+
+Pass the full dataset through the registered MLflow model in a single call. We request `probas` so we can compute ROC-AUC alongside accuracy and F1.
+
+```python
+# Predict via the registered MLflow model
+predictions = loaded_model.predict(pd.DataFrame([{
+ "task_config": {
+ "task": "classification",
+ "predict_params": {"output_type": "probas"},
+ },
+ "X_train": X_train.values.tolist(),
+ "y_train": y_train.values.tolist(),
+ "X_test": X_test.values.tolist(),
+}]))
+
+probas = np.array(predictions)
+y_pred = probas.argmax(axis=1)
+
+print(f"Predicted {len(y_pred):,} samples")
+```
+
+### Evaluate Results
+
+```python
+from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
+
+print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
+print(f"F1 (weighted): {f1_score(y_test, y_pred, average='weighted'):.4f}")
+print(f"ROC-AUC: {roc_auc_score(y_test, probas[:, 1]):.4f}")
+
+# Materialize as a Spark DataFrame
+results_sdf = spark.createDataFrame(pd.DataFrame({
+ "actual": y_test.values,
+ "predicted": y_pred,
+ "probability_bad_loan": np.round(probas[:, 1], 4),
+}))
+
+display(results_sdf)
+```
+
+No fine-tuning, no feature engineering pipeline, no hyperparameter search - TabPFN fits and predicts in a single forward pass.
+
+---
+
+## Step 7: Deploy to Mosaic AI Model Serving
+
+Deploy the registered model to a GPU-accelerated serving endpoint. The TabPFN token is securely passed via Databricks Secrets - it is never stored in the endpoint configuration.
+
+```python
+import mlflow.deployments
+
+# Resolve the "champion" alias to a version number
+client = mlflow.MlflowClient()
+champion = client.get_model_version_by_alias(REGISTERED_MODEL_NAME, "champion")
+print(f"Deploying {REGISTERED_MODEL_NAME} version {champion.version}")
+
+# Get the deployment MLflow client
+client = mlflow.deployments.get_deploy_client("databricks")
+
+# Create the endpoint, will return immediately and continue initializing the endpoint async
+# Check for the status in your Databricks console
+endpoint = client.create_endpoint(
+ name="tabpfn-endpoint",
+ config={
+ "served_entities": [{
+ "entity_name": REGISTERED_MODEL_NAME,
+ "entity_version": str(champion.version),
+ "workload_size": "Medium",
+ "workload_type": "GPU_MEDIUM",
+ "scale_to_zero_enabled": True,
+ "environment_vars": {
+ "TABPFN_TOKEN": "{{secrets/tabpfn/tabpfn_token}}",
+ },
+ }],
+ },
+)
+
+print(f"Endpoint created: {endpoint['name']}")
+```
+
+
+ `create_endpoint` returns immediately and initializes the endpoint asynchronously. Monitor the status in your Databricks console under **Serving** → **tabpfn-endpoint**.
+
+
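If you prefer to block until the endpoint is usable, you can poll its state from the notebook. A hedged sketch - the exact shape of the status dict returned by `get_endpoint` can vary across Databricks releases, so adjust the key lookups to match what your workspace reports:

```python
import time

def wait_until_ready(client, endpoint_name, timeout_s=1800, poll_s=30):
    """Poll a serving endpoint until its state reports READY.

    `client` is an mlflow.deployments client. The "state"/"ready" keys
    below are an assumption about the status payload; verify against
    your workspace if the loop never exits.
    """
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        status = client.get_endpoint(endpoint=endpoint_name)
        if status.get("state", {}).get("ready", "") == "READY":
            return status
        time.sleep(poll_s)
    raise TimeoutError(f"{endpoint_name} not ready after {timeout_s}s")
```

After `create_endpoint` returns, `wait_until_ready(client, "tabpfn-endpoint")` blocks until the first deployment finishes - cold starts for GPU endpoints can take several minutes.
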
+### Calling the Endpoint
+
+Once the endpoint is live, you can reach it from Python, REST, or SQL:
+
+
+
+ ```python
+ import mlflow.deployments
+
+ client = mlflow.deployments.get_deploy_client("databricks")
+ response = client.predict(
+ endpoint="tabpfn-endpoint",
+ inputs={"dataframe_records": [{
+ "task_config": json.dumps({"task": "classification"}),
+ "X_train": json.dumps([[1.0, 2.0], [3.0, 4.0]]),
+ "y_train": json.dumps([0.0, 1.0]),
+ "X_test": json.dumps([[2.0, 3.0]]),
+ }]},
+ )
+ print(response)
+ ```
+
+
+ ```bash
+ curl -X POST \
+    https://<workspace-url>/serving-endpoints/tabpfn-endpoint/invocations \
+ -H "Authorization: Bearer $DATABRICKS_TOKEN" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "dataframe_records": [{
+ "task_config": "{\"task\": \"classification\"}",
+ "X_train": "[[1.0, 2.0], [3.0, 4.0]]",
+ "y_train": "[0.0, 1.0]",
+ "X_test": "[[2.0, 3.0]]"
+ }]
+ }'
+ ```
+
+
+ ```sql
+ SELECT ai_query(
+ 'tabpfn-endpoint',
+ named_struct(
+ 'task_config', to_json(named_struct('task', 'classification')),
+ 'X_train', to_json(array(array(1.0, 2.0), array(3.0, 4.0))),
+ 'y_train', to_json(array(0.0, 1.0)),
+ 'X_test', to_json(array(array(2.0, 3.0)))
+ )
+ ) AS prediction
+ ```
+
+
+
+---
+
+## How the Input Format Works
+
+Understanding `task_config` is key to using the endpoint effectively. It controls both what TabPFN does and how it does it.
+
+| Field | Type | Description |
+|---|---|---|
+| `task` | `string` | Required. `"classification"` or `"regression"` |
+| `tabpfn_config` | `object` | Optional. Passed directly to `TabPFNClassifier` / `TabPFNRegressor` constructor (e.g. `n_estimators`, `softmax_temperature`) |
+| `predict_params` | `object` | Optional. Controls output format via `output_type` |
+
+### Output Types
+
+
+
+ | `output_type` | Description |
+ |---|---|
+ | `preds` (default) | Predicted class labels |
+ | `probas` | Class probabilities (nested list) |
+
+
+ | `output_type` | Description |
+ |---|---|
+ | `mean` (default) | Mean prediction |
+ | `mode` | Mode of predictive distribution |
+ | `median` | Median prediction |
+ | `quantiles` | Quantile predictions |
+ | `main` | Main summary statistics |
+ | `full` | Full predictive distribution |
+
+
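Whichever `output_type` you pick, only `predict_params` changes - the rest of the request row is identical. A small helper (illustrative, not part of the wrapper) that assembles rows in the all-string format the signature expects:

```python
import json

def make_row(task, X_train, y_train, X_test, output_type=None):
    """Assemble one request row in the all-string input format."""
    task_config = {"task": task}
    if output_type is not None:
        task_config["predict_params"] = {"output_type": output_type}
    return {
        "task_config": json.dumps(task_config),
        "X_train": json.dumps(X_train),
        "y_train": json.dumps(y_train),
        "X_test": json.dumps(X_test),
    }

# Classification with probabilities vs. regression with quantiles:
clf_row = make_row("classification", [[1.0, 2.0], [3.0, 4.0]], [0.0, 1.0], [[2.0, 3.0]], "probas")
reg_row = make_row("regression", [[1.0], [2.0]], [1.5, 2.5], [[1.5]], "quantiles")
assert json.loads(clf_row["task_config"])["predict_params"]["output_type"] == "probas"
```
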
+
+---
+
+## Promoting a New Model Version
+
+When you want to update the model (e.g. after a new TabPFN release), simply re-run the registration cell and reassign the alias. The endpoint keeps routing to `champion` with no configuration change needed.
+
+```python
+# After re-running the registration cell with a new TabPFNWrapper...
+client = mlflow.MlflowClient()
+versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
+latest = max(versions, key=lambda v: int(v.version))
+client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
+print(f"Promoted version {latest.version} to 'champion'")
+```
+
+---
+
+## Next Steps
+
+
+
+ Explore all supported parameters, estimator options, and output types.
+
+
+ Learn how TabPFN accelerates business transformation on Databricks.
+
+
+ Understand the `mlflow.pyfunc.PythonModel` interface used by the wrapper.
+
+
+ Configure autoscaling, traffic splitting, and monitoring for your endpoint.
+
+
diff --git a/docs.json b/docs.json
index 7b3c6d7..1e53f52 100644
--- a/docs.json
+++ b/docs.json
@@ -148,6 +148,14 @@
"integrations/sagemaker"
]
},
+ {
+ "group": "Cookbooks",
+ "icon": "book",
+ "pages": [
+ "cookbooks/mlflow"
+ ]
+ },
{
"group": "Use Cases",
"pages": [
diff --git a/tutorials/mlflow.mdx b/tutorials/mlflow.mdx
new file mode 100644
index 0000000..cd532d8
--- /dev/null
+++ b/tutorials/mlflow.mdx
@@ -0,0 +1,641 @@
+---
+title: "TabPFN with MLflow"
+description: "Learn how to wrap TabPFN as an MLflow PythonModel, register it to Unity Catalog, and deploy it to a Mosaic AI Model Serving endpoint."
+---
+
+## Overview
+
+[TabPFN](https://github.com/priorlabs/tabpfn) is a tabular foundation model that makes accurate predictions in a single forward pass - no dataset-specific training, no hyperparameter search, no separate model artifact per target variable.
+
+The promise: **one registered model** that handles both classification and regression across any tabular dataset.
+
+
+
+ Manage one model instead of many. No per-dataset training runs, no hyperparameter searches, no separate artifacts per outcome.
+
+
+ Use a single model for classification and regression across any tabular dataset - from a notebook or SQL `ai_query()`.
+
+
+
+By the end of this tutorial you will have:
+
+- Wrapped TabPFN as an MLflow `PythonModel`
+- Registered it to Unity Catalog with a `champion` alias
+- Tested it locally on classification and regression tasks
+- Run an end-to-end example on the Lending Club loan dataset
+- Deployed it to a GPU-accelerated Mosaic AI Model Serving endpoint
+
+
+ For the full list of supported TabPFN parameters (estimators, output types, configuration options) see the [TabPFN GitHub repository](https://github.com/priorlabs/tabpfn). For broader context on TabPFN and Databricks, read the [Databricks blog post](https://www.databricks.com/blog/tappfn-ai-accelerates-business-transformation-databricks).
+
+
+---
+
+## Prerequisites
+
+- A Databricks workspace with Unity Catalog enabled
+- Access to a **serverless GPU environment** with at least an **NVIDIA A10G** GPU
+- A TabPFN token from [Prior Labs](https://github.com/priorlabs/tabpfn?tab=readme-ov-file#installation--setup)
+
+
+ TabPFN runs inference via PyTorch, so a GPU significantly speeds up predictions - especially on larger datasets. In the notebook toolbar, select **Serverless** with a GPU-enabled instance (A10G or better). Serverless handles provisioning and scaling automatically.
+
+
+---
+
+## Step 1: Install Dependencies
+
+Run the following in your Databricks notebook cell:
+
+```python
+%pip install mlflow tabpfn
+dbutils.library.restartPython()
+```
+
+---
+
+## Step 2: Configure Your TabPFN Token
+
+TabPFN model weights are gated and require authentication. Store your token as a Databricks secret so it never lives in plain text in your notebook.
+
+Run these commands in the **Databricks CLI**:
+
+```bash
+databricks secrets create-scope tabpfn
+databricks secrets put-secret tabpfn tabpfn_token --string-value "<your-tabpfn-token>"
+```
+
+Then reference the secret in your notebook:
+
+```python
+import os
+
+os.environ["TABPFN_TOKEN"] = dbutils.secrets.get(scope="tabpfn", key="tabpfn_token")
+```
+
+---
+
+## Step 3: Define the Wrapper and Signature
+
+The `TabPFNWrapper` is a custom `mlflow.pyfunc.PythonModel`. It is the heart of this integration and handles three concerns:
+
+- **Dual-format input** - accepts both raw Python objects (from notebooks) and JSON strings (from SQL `ai_query()`)
+- **Task routing** - classification or regression, controlled by `task_config`
+- **Flexible output** - class labels, probabilities, or regression predictions based on `output_type`
+
+All input columns use `DataType.string` so the same endpoint works from Python, REST, and SQL without any changes to the registered model. The `_maybe_parse_json()` helper transparently handles both formats at predict time.
+
+```python
+import json
+import inspect
+
+import mlflow
+import numpy as np
+import pandas as pd
+
+from typing import Literal
+
+from mlflow.models.signature import ModelSignature
+from mlflow.types.schema import ColSpec, DataType, Schema
+from tabpfn import TabPFNClassifier, TabPFNRegressor
+
+
+class TabPFNWrapper(mlflow.pyfunc.PythonModel):
+ """MLflow PythonModel wrapper for TabPFN"""
+
+ _CLASSIFICATION_OUTPUT_TYPES = {"preds", "probas"}
+ """The model output types for classification."""
+
+ _REGRESSION_OUTPUT_TYPES = {"mean", "mode", "median", "quantiles", "main", "full"}
+ """The model output types for regression."""
+
+ @staticmethod
+ def _maybe_parse_json(value: str | dict | list) -> dict | list:
+ """Helper function to parse a JSON string if needed.
+
+ Args:
+ value: The value to parse. Can be a string, dict, or list.
+
+ Returns:
+ The parsed value.
+ """
+ if isinstance(value, str):
+ return json.loads(value)
+ return value
+
+ def _get_output_type(
+ self,
+ task: Literal["classification", "regression"],
+ output_type: str | None
+ ) -> str:
+        """Get the prediction output type.
+
+        Args:
+            task: The prediction task, either "classification" or "regression".
+            output_type: The requested output type, or None to fall back to the
+                task default ("preds" for classification, "mean" for regression).
+
+        Returns:
+            The validated prediction output type.
+        """
+        if output_type is not None:
+            supported_output_types = (
+                self._CLASSIFICATION_OUTPUT_TYPES if task == "classification"
+                else self._REGRESSION_OUTPUT_TYPES
+            )
+            if output_type not in supported_output_types:
+                raise ValueError(
+                    f"Unknown output_type for {task}: {output_type!r}. "
+                    f"Must be one of {sorted(supported_output_types)}"
+                )
+            return output_type
+
+ # Fallback to the defaults
+ return "preds" if task == "classification" else "mean"
+
+ def _init_estimator(
+ self,
+ task: Literal["classification", "regression"],
+ config: dict
+ ) -> TabPFNClassifier | TabPFNRegressor:
+ """Initialize a TabPFN estimator.
+
+ Args:
+ task: The task to initialize the estimator for.
+ config: The configuration for the estimator.
+
+ Returns:
+ The initialized estimator.
+ """
+ Estimator = TabPFNClassifier if task == "classification" else TabPFNRegressor
+
+ # Validate provided config keys against the estimator constructor
+ sig = inspect.signature(Estimator.__init__)
+
+ constructor_params = set(sig.parameters.keys()) - {"self"}
+ supplied_keys = set(config.keys())
+ invalid_keys = supplied_keys - constructor_params
+
+ # Raise an error if any invalid keys are provided
+ if invalid_keys:
+ msg = (
+ f"Config contains invalid parameters for {Estimator.__name__}: {sorted(invalid_keys)}.\n"
+ f"Allowed parameters: {sorted(constructor_params)}"
+ )
+ raise ValueError(msg)
+
+ return Estimator(**config)
+
+ def predict(self, model_input, params=None):
+ """Run predictions.
+
+ TabPFN runs predictions in a single forward pass. The model is
+ fitted on the training data and then used to predict on the test data.
+
+ Args:
+ model_input: The input data to predict on.
+ params: The parameters for the model.
+
+ Returns:
+ The predictions.
+ """
+ # Accept both DataFrame (local pyfunc) and list-of-dicts (serving endpoint)
+ if isinstance(model_input, pd.DataFrame):
+ model_input = model_input.to_dict(orient="records")
+
+        if not isinstance(model_input, list):
+            raise TypeError("model_input must be a list with a single row")
+        if len(model_input) != 1:
+            raise ValueError("model_input must have a single row")
+
+ model_input: dict = model_input[0]
+
+ # Parse the task configuration
+ task_config = self._maybe_parse_json(model_input["task_config"]) or {}
+
+ task = task_config.get("task")
+ if task is None:
+ raise KeyError("Task is required, must be 'classification' or 'regression'")
+
+ tabpfn_config = task_config.get("tabpfn_config") or {}
+ predict_params = task_config.get("predict_params") or {}
+
+ output_type = self._get_output_type(task, predict_params.get("output_type"))
+
+ # Parse the input data
+ X_train = np.array(self._maybe_parse_json(model_input["X_train"]), dtype=np.float64)
+ y_train = np.array(self._maybe_parse_json(model_input["y_train"]), dtype=np.float64)
+ X_test = np.array(self._maybe_parse_json(model_input["X_test"]), dtype=np.float64)
+
+ # Initialize the estimator
+ estimator = self._init_estimator(task, tabpfn_config)
+
+ # Fit the estimator
+ estimator.fit(X_train, y_train)
+
+ # Run predictions
+ if task == "classification":
+ if output_type == "probas":
+ return estimator.predict_proba(X_test).tolist()
+ return estimator.predict(X_test).tolist()
+
+ predictions = estimator.predict(X_test, output_type=output_type)
+ return predictions.tolist()
+```
+
+### Define the Model Signature
+
+The model signature tells MLflow (and Unity Catalog) the expected input and output shapes. All columns are `DataType.string` to support both raw Python and JSON-serialized inputs from SQL `ai_query()`.
+
+```python
+# All inputs are STRING for dual-format support:
+# - JSON strings from SQL ai_query() via to_json()
+# - Raw Python objects (dicts/lists) passed in a pd.DataFrame from notebook calls
+
+input_schema = Schema([
+ ColSpec(DataType.string, "task_config"),
+ ColSpec(DataType.string, "X_train"),
+ ColSpec(DataType.string, "y_train"),
+ ColSpec(DataType.string, "X_test"),
+])
+
+# Output schema is required by Unity Catalog.
+# STRING is the best fit - the wrapper returns variable shapes
+# (flat list for preds, nested list for probas) and the serving
+# layer serializes whatever .tolist() produces into a JSON string.
+output_schema = Schema([ColSpec(DataType.string, name="predictions")])
+signature = ModelSignature(inputs=input_schema, outputs=output_schema)
+```
+
+
+ The `output_schema` is **required** for Unity Catalog registration. Using `DataType.string` is the right choice here because the wrapper can return either a flat list (for `preds`) or a nested list (for `probas`/regression quantiles), and the serving layer serializes the result to JSON regardless.
+
+
+---
+
+## Step 4: Register to Unity Catalog
+
+Log the model with MLflow and register it under a fully qualified Unity Catalog path (`catalog.schema.tabpfn`). The `input_example` uses `json.dumps()` to match the all-string signature - the wrapper deserializes at predict time via `_maybe_parse_json()`.
+
+After registration, we tag the latest version with a `"champion"` alias. The serving endpoint references this alias, so you can promote future versions without touching the endpoint config.
+
+```python
+# Input example uses json.dumps() to match the all-string signature
+# The wrapper's _maybe_parse_json() deserializes these at predict time
+input_example = pd.DataFrame([{
+ "task_config": json.dumps({
+ "task": "classification",
+ "tabpfn_config": {
+ "n_estimators": 8,
+ "softmax_temperature": 0.9,
+ },
+ "predict_params": {
+ "output_type": "preds",
+ },
+ }),
+ "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]),
+ "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]),
+ "X_test": json.dumps([[2.0, 3.0, 0.0]]),
+}])
+
+# Fully qualified Unity Catalog path (portable across workspaces)
+CATALOG = spark.catalog.currentCatalog()
+SCHEMA = spark.catalog.currentDatabase()
+REGISTERED_MODEL_NAME = f"{CATALOG}.{SCHEMA}.tabpfn"
+
+with mlflow.start_run(run_name="tabpfn-registration") as run:
+ model_info = mlflow.pyfunc.log_model(
+ name="tabpfn",
+ python_model=TabPFNWrapper(),
+ signature=signature,
+ input_example=input_example,
+ pip_requirements=["tabpfn", "numpy", "pandas"],
+ registered_model_name=REGISTERED_MODEL_NAME,
+ )
+ print(f"Model URI: {model_info.model_uri}")
+ print(f"Run ID: {run.info.run_id}")
+```
+
+Once logged, assign the `champion` alias to the latest version:
+
+```python
+# Assign the "champion" alias to the latest version for serving
+client = mlflow.MlflowClient()
+versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
+latest = max(versions, key=lambda v: int(v.version))
+client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
+print(f"Set alias 'champion' → version {latest.version} of {REGISTERED_MODEL_NAME}")
+```
+
+
+ Using the `champion` alias decouples your serving endpoint from version numbers. When you retrain or update TabPFN, simply reassign the alias - the endpoint continues to route requests without any configuration change.
+
+
+---
+
+## Step 5: Test Locally
+
+Before deploying, verify the registered model works end-to-end from the notebook. The same model handles both raw Python objects and JSON strings - just load it and call `.predict()`.
+
+
+
+ ```python
+ # Load the registered model from MLflow
+ loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
+
+ # DataFrame with raw Python objects, the notebook-friendly example
+ predictions = loaded_model.predict(pd.DataFrame([{
+ "task_config": {"task": "classification"},
+ "X_train": [
+ [1.0, 2.0, 0.0],
+ [3.0, 4.0, 1.0],
+ [5.0, 6.0, 0.0],
+ [7.0, 8.0, 1.0]
+ ],
+ "y_train": [0.0, 1.0, 0.0, 1.0],
+ "X_test": [
+ [2.0, 3.0, 0.0]
+ ],
+ }]))
+ predictions
+ ```
+
+
+ ```python
+ loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
+
+ predictions = loaded_model.predict(pd.DataFrame([{
+ "task_config": json.dumps({"task": "classification", "predict_params": {"output_type": "probas"}}),
+ "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]),
+ "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]),
+ "X_test": json.dumps([[2.0, 3.0, 0.0]]),
+ }]))
+ predictions
+ ```
+
+
+ ```python
+ loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
+
+ predictions = loaded_model.predict(pd.DataFrame([{
+ "task_config": {"task": "regression", "predict_params": {"output_type": "mean"}},
+ "X_train": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
+ "y_train": [1.5, 3.5, 5.5, 7.5],
+ "X_test": [[2.0, 3.0]],
+ }]))
+ predictions
+ ```
+
+
+
+The wrapper's `_maybe_parse_json()` transparently handles both formats - no code changes required between calling from a notebook and calling from a REST endpoint.
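+
+The real helper lives in the wrapper class defined earlier; conceptually it is just a guarded `json.loads`. A minimal sketch of the idea (illustrative, not the exact implementation):
+
+```python
+import json
+
+def maybe_parse_json(value):
+    """Illustrative stand-in for the wrapper's _maybe_parse_json():
+    JSON strings (REST/SQL callers) are decoded; raw Python objects
+    (notebook callers) pass through unchanged."""
+    if isinstance(value, str):
+        return json.loads(value)
+    return value
+
+# Both call styles normalize to the same structure
+assert maybe_parse_json('[[1.0, 2.0]]') == [[1.0, 2.0]]
+assert maybe_parse_json([[1.0, 2.0]]) == [[1.0, 2.0]]
+```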
+
+---
+
+## Step 6: End-to-End Example - Lending Club Loan Data
+
+Let's run a real-world classification task: predicting loan default (Good vs Bad) on the [Lending Club Q2 2018](https://www.kaggle.com/datasets/wordsforthewise/lending-club) dataset. This dataset ships with every Databricks workspace at `/databricks-datasets/`, so you can run this without any additional data download.
+
+### Load and Prepare the Data
+
+```python
+from sklearn.model_selection import train_test_split
+
+# Load Lending Club Q2 2018 (ships with every Databricks workspace)
+df = (
+ spark.read.csv(
+ "/databricks-datasets/lending-club-loan-stats/LoanStats_2018Q2.csv",
+ header=True, inferSchema=True,
+ )
+ .select(
+ "loan_status",
+ "loan_amnt", "funded_amnt", "installment", "annual_inc", "dti",
+ "open_acc", "revol_bal", "total_acc", "delinq_2yrs", "inq_last_6mths",
+ "pub_rec", "mort_acc", "tot_cur_bal", "total_pymnt", "last_pymnt_amnt",
+ )
+ .dropna(subset=["loan_status"])
+ .toPandas()
+)
+
+# Binary target: Good (0) vs Bad (1)
+df["target"] = (df["loan_status"].apply(
+ lambda s: 0 if s in ("Current", "Fully Paid") else 1
+).astype(int))
+df = df.drop(columns=["loan_status"]).dropna()
+
+# Sample and split
+df_sample = df.sample(n=6_000, random_state=42)
+X_train, X_test, y_train, y_test = train_test_split(
+ df_sample.drop(columns=["target"]), df_sample["target"],
+ test_size=0.2, random_state=42, stratify=df_sample["target"],
+)
+
+print(f"Train: {len(X_train):,} × {X_train.shape[1]} (bad-loan rate: {y_train.mean():.2%})")
+print(f"Test: {len(X_test):,} × {X_test.shape[1]}")
+```
+
+### Run Predictions
+
+Pass the training split and the test features through the registered MLflow model in a single call. We request `probas` so we can compute ROC-AUC alongside accuracy and F1.
+
+```python
+# Predict via the registered MLflow model
+predictions = loaded_model.predict(pd.DataFrame([{
+ "task_config": {
+ "task": "classification",
+ "predict_params": {"output_type": "probas"},
+ },
+ "X_train": X_train.values.tolist(),
+ "y_train": y_train.values.tolist(),
+ "X_test": X_test.values.tolist(),
+}]))
+
+probas = np.array(predictions)
+y_pred = probas.argmax(axis=1)
+
+print(f"Predicted {len(y_pred):,} samples")
+```
+
+### Evaluate Results
+
+```python
+from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
+
+print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
+print(f"F1 (weighted): {f1_score(y_test, y_pred, average='weighted'):.4f}")
+print(f"ROC-AUC: {roc_auc_score(y_test, probas[:, 1]):.4f}")
+
+# Materialize as a Spark DataFrame
+results_sdf = spark.createDataFrame(pd.DataFrame({
+ "actual": y_test.values,
+ "predicted": y_pred,
+ "probability_bad_loan": np.round(probas[:, 1], 4),
+}))
+
+display(results_sdf)
+```
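+
+Aggregate metrics can hide where the model errs. For a class-level breakdown, a confusion matrix is a quick follow-up - the sketch below uses placeholder arrays so it runs standalone (substitute your `y_test` and `y_pred`):
+
+```python
+from sklearn.metrics import confusion_matrix
+import pandas as pd
+
+# Placeholder labels standing in for y_test / y_pred from above
+actual    = [0, 0, 1, 1, 0, 1, 0, 1]
+predicted = [0, 0, 1, 0, 0, 1, 1, 1]
+
+cm = confusion_matrix(actual, predicted, labels=[0, 1])
+cm_df = pd.DataFrame(
+    cm,
+    index=["actual_good", "actual_bad"],
+    columns=["pred_good", "pred_bad"],
+)
+print(cm_df)
+```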
+
+No fine-tuning, no feature engineering pipeline, no hyperparameter search - TabPFN fits and predicts in a single forward pass.
+
+---
+
+## Step 7: Deploy to Mosaic AI Model Serving
+
+Deploy the registered model to a GPU-accelerated serving endpoint. The TabPFN token is securely passed via Databricks Secrets - it is never stored in the endpoint configuration.
+
+```python
+import mlflow.deployments
+
+# Resolve the "champion" alias to a version number
+client = mlflow.MlflowClient()
+champion = client.get_model_version_by_alias(REGISTERED_MODEL_NAME, "champion")
+print(f"Deploying {REGISTERED_MODEL_NAME} version {champion.version}")
+
+# Get the deployment MLflow client
+client = mlflow.deployments.get_deploy_client("databricks")
+
+# Create the endpoint; this returns immediately and the endpoint
+# continues initializing asynchronously - check its status in the
+# Databricks console under Serving
+endpoint = client.create_endpoint(
+ name="tabpfn-endpoint",
+ config={
+ "served_entities": [{
+ "entity_name": REGISTERED_MODEL_NAME,
+ "entity_version": str(champion.version),
+ "workload_size": "Medium",
+ "workload_type": "GPU_MEDIUM",
+ "scale_to_zero_enabled": True,
+ "environment_vars": {
+ "HF_TOKEN": "{{secrets/tabpfn/hf_token}}",
+ },
+ }],
+ },
+)
+
+print(f"Endpoint created: {endpoint['name']}")
+```
+
+
+ `create_endpoint` returns immediately and initializes the endpoint asynchronously. Monitor the status in your Databricks console under **Serving** → **tabpfn-endpoint**.
+
+
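+Rather than refreshing the console, you can poll readiness from the notebook. The helper below interprets the status dict returned by `client.get_endpoint()`; the `state.ready` shape is an assumption based on the Databricks serving API, so verify it against your workspace's actual response. The polling loop is commented out because it needs a live workspace:
+
+```python
+import time
+
+def endpoint_is_ready(endpoint_info: dict) -> bool:
+    """Interpret an endpoint status dict. Assumes the Databricks
+    response shape {"state": {"ready": "READY", ...}}."""
+    return endpoint_info.get("state", {}).get("ready") == "READY"
+
+# Example polling loop (uncomment inside a Databricks notebook):
+# client = mlflow.deployments.get_deploy_client("databricks")
+# while not endpoint_is_ready(client.get_endpoint(endpoint="tabpfn-endpoint")):
+#     time.sleep(30)
+# print("Endpoint is ready")
+
+# The helper works on plain dicts, so it can be sanity-checked offline:
+print(endpoint_is_ready({"state": {"ready": "READY"}}))      # True
+print(endpoint_is_ready({"state": {"ready": "NOT_READY"}}))  # False
+```
+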
+### Calling the Endpoint
+
+Once the endpoint is live, you can reach it from Python, REST, or SQL:
+
+
+
+ ```python
+ import mlflow.deployments
+
+ client = mlflow.deployments.get_deploy_client("databricks")
+ response = client.predict(
+ endpoint="tabpfn-endpoint",
+ inputs={"dataframe_records": [{
+ "task_config": json.dumps({"task": "classification"}),
+ "X_train": json.dumps([[1.0, 2.0], [3.0, 4.0]]),
+ "y_train": json.dumps([0.0, 1.0]),
+ "X_test": json.dumps([[2.0, 3.0]]),
+ }]},
+ )
+ print(response)
+ ```
+
+
+ ```bash
+ curl -X POST \
+  https://<workspace-url>/serving-endpoints/tabpfn-endpoint/invocations \
+ -H "Authorization: Bearer $DATABRICKS_TOKEN" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "dataframe_records": [{
+ "task_config": "{\"task\": \"classification\"}",
+ "X_train": "[[1.0, 2.0], [3.0, 4.0]]",
+ "y_train": "[0.0, 1.0]",
+ "X_test": "[[2.0, 3.0]]"
+ }]
+ }'
+ ```
+
+
+ ```sql
+ SELECT ai_query(
+ 'tabpfn-endpoint',
+ named_struct(
+ 'task_config', to_json(named_struct('task', 'classification')),
+ 'X_train', to_json(array(array(1.0, 2.0), array(3.0, 4.0))),
+ 'y_train', to_json(array(0.0, 1.0)),
+ 'X_test', to_json(array(array(2.0, 3.0)))
+ )
+ ) AS prediction
+ ```
+
+
+
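+
+Because the output schema is `STRING`, the `predictions` field comes back as a JSON string that the caller decodes. A minimal sketch - the response envelope below is a hypothetical shape for illustration; inspect your endpoint's actual payload before relying on it:
+
+```python
+import json
+
+# Hypothetical response payload; the inner value is a JSON string
+# per the STRING output schema
+response = {"predictions": ["[[0.35, 0.65]]"]}
+
+# Decode the JSON string back into a nested probability list
+probas = json.loads(response["predictions"][0])
+print(probas)  # [[0.35, 0.65]]
+```
+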
+---
+
+## How the Input Format Works
+
+Understanding `task_config` is key to using the endpoint effectively. It controls both what TabPFN does and how it does it.
+
+| Field | Type | Description |
+|---|---|---|
+| `task` | `string` | Required. `"classification"` or `"regression"` |
+| `tabpfn_config` | `object` | Optional. Passed directly to `TabPFNClassifier` / `TabPFNRegressor` constructor (e.g. `n_estimators`, `softmax_temperature`) |
+| `predict_params` | `object` | Optional. Controls output format via `output_type` |
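+
+Putting the three fields together, a fully specified `task_config` looks like this (the parameter values are illustrative; see the TabPFN repository for the full option list):
+
+```python
+import json
+
+task_config = {
+    "task": "classification",        # required
+    "tabpfn_config": {               # optional constructor kwargs
+        "n_estimators": 8,
+        "softmax_temperature": 0.9,
+    },
+    "predict_params": {              # optional output control
+        "output_type": "probas",
+    },
+}
+
+# REST/SQL callers send this as a JSON string, matching the signature
+payload_field = json.dumps(task_config)
+print(payload_field)
+```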
+
+### Output Types
+
+
+
+ | `output_type` | Description |
+ |---|---|
+ | `preds` (default) | Predicted class labels |
+ | `probas` | Class probabilities (nested list) |
+
+
+ | `output_type` | Description |
+ |---|---|
+ | `mean` (default) | Mean prediction |
+ | `mode` | Mode of predictive distribution |
+ | `median` | Median prediction |
+ | `quantiles` | Quantile predictions |
+ | `main` | Main summary statistics |
+ | `full` | Full predictive distribution |
+
+
+
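+
+For example, a quantile-regression request might be assembled like this. The `quantiles` key is an assumption here - it mirrors the TabPFN `predict()` keyword forwarded via `predict_params`, so check the repository for the exact supported parameters:
+
+```python
+import json
+import pandas as pd
+
+# Hypothetical quantile-regression request; "quantiles" is assumed
+# to be forwarded to the regressor's predict() via predict_params
+request = pd.DataFrame([{
+    "task_config": json.dumps({
+        "task": "regression",
+        "predict_params": {
+            "output_type": "quantiles",
+            "quantiles": [0.1, 0.5, 0.9],
+        },
+    }),
+    "X_train": json.dumps([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
+    "y_train": json.dumps([1.5, 3.5, 5.5]),
+    "X_test": json.dumps([[2.0, 3.0]]),
+}])
+
+# predictions = loaded_model.predict(request)  # one row of quantiles per X_test row
+print(request["task_config"].iloc[0])
+```
+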
+---
+
+## Promoting a New Model Version
+
+When you want to update the model (e.g. after a new TabPFN release), simply re-run the registration cell and reassign the alias. The endpoint keeps routing to `champion` with no configuration change needed.
+
+```python
+# After re-running the registration cell with a new TabPFNWrapper...
+client = mlflow.MlflowClient()
+versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
+latest = max(versions, key=lambda v: int(v.version))
+client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
+print(f"Promoted version {latest.version} to 'champion'")
+```
+
+---
+
+## Next Steps
+
+
+
+ Explore all supported parameters, estimator options, and output types.
+
+
+ Learn how TabPFN accelerates business transformation on Databricks.
+
+
+ Understand the `mlflow.pyfunc.PythonModel` interface used by the wrapper.
+
+
+ Configure autoscaling, traffic splitting, and monitoring for your endpoint.
+
+