diff --git a/cookbooks/mlflow.mdx b/cookbooks/mlflow.mdx new file mode 100644 index 0000000..7790e9a --- /dev/null +++ b/cookbooks/mlflow.mdx @@ -0,0 +1,641 @@ +--- +title: "TabPFN with MLflow" +description: "Learn how to wrap TabPFN as an MLflow PythonModel, register it to Unity Catalog, and deploy it to a Mosaic AI Model Serving endpoint." +--- + +## Overview + +[TabPFN](https://github.com/priorlabs/tabpfn) is a tabular foundation model that makes accurate predictions in a single forward pass - no dataset-specific training, no hyperparameter search, no separate model artifact per target variable. + +The promise: **one registered model** that handles both classification and regression across any tabular dataset. + + + + Manage one model instead of many. No per-dataset training runs, no hyperparameter searches, no separate artifacts per outcome. + + + Use a single model for classification and regression across any tabular dataset - from a notebook or SQL `ai_query()`. + + + +By the end of this tutorial you will have: + +- Wrapped TabPFN as an MLflow `PythonModel` +- Registered it to Unity Catalog with a `champion` alias +- Tested it locally on classification and regression tasks +- Run an end-to-end example on the Lending Club loan dataset +- Deployed it to a GPU-accelerated Mosaic AI Model Serving endpoint + + + For the full list of supported TabPFN parameters (estimators, output types, configuration options) see the [TabPFN GitHub repository](https://github.com/priorlabs/tabpfn). For broader context on TabPFN and Databricks, read the [Databricks blog post](https://www.databricks.com/blog/tabpfn-ai-accelerates-business-transformation-databricks). 
+ + +--- + +## Prerequisites + +- A Databricks workspace with Unity Catalog enabled +- Access to a **serverless GPU environment** with at least an **NVIDIA A10G** GPU +- A TabPFN token from [Prior Labs](https://github.com/priorlabs/tabpfn?tab=readme-ov-file#installation--setup) + + + TabPFN runs inference via PyTorch, so a GPU significantly speeds up predictions - especially on larger datasets. In the notebook toolbar, select **Serverless** with a GPU-enabled instance (A10G or better). Serverless handles provisioning and scaling automatically. + + +--- + +## Step 1: Install Dependencies + +Run the following in your Databricks notebook cell: + +```python +%pip install mlflow tabpfn +dbutils.library.restartPython() +``` + +--- + +## Step 2: Configure Your TabPFN Token + +TabPFN model weights are gated and require authentication. Store your token as a Databricks secret so it never lives in plain text in your notebook. + +Run these commands in the **Databricks CLI**: + +```bash +databricks secrets create-scope tabpfn +databricks secrets put-secret tabpfn tabpfn_token --string-value "" +``` + +Then reference the secret in your notebook: + +```python +import os + +os.environ["TABPFN_TOKEN"] = dbutils.secrets.get(scope="tabpfn", key="tabpfn_token") +``` + +--- + +## Step 3: Define the Wrapper and Signature + +The `TabPFNWrapper` is a custom `mlflow.pyfunc.PythonModel`. It is the heart of this integration and handles three concerns: + +- **Dual-format input** - accepts both raw Python objects (from notebooks) and JSON strings (from SQL `ai_query()`) +- **Task routing** - classification or regression, controlled by `task_config` +- **Flexible output** - class labels, probabilities, or regression predictions based on `output_type` + +All input columns use `DataType.string` so the same endpoint works from Python, REST, and SQL without any changes to the registered model. The `_maybe_parse_json()` helper transparently handles both formats at predict time. 
+ +```python +import os +import json +import inspect + +import mlflow +import numpy as np +import pandas as pd + +from typing import Literal + +from mlflow.models.signature import ModelSignature +from mlflow.types.schema import ( + Array, + ColSpec, + DataType, + Object, + ParamSchema, + ParamSpec, + Property, + Schema +) +from tabpfn import TabPFNClassifier, TabPFNRegressor + + +class TabPFNWrapper(mlflow.pyfunc.PythonModel): + """MLflow PythonModel wrapper for TabPFN""" + + _CLASSIFICATION_OUTPUT_TYPES = {"preds", "probas"} + """The model output types for classification.""" + + _REGRESSION_OUTPUT_TYPES = {"mean", "mode", "median", "quantiles", "main", "full"} + """The model output types for regression.""" + + @staticmethod + def _maybe_parse_json(value: str | dict | list) -> dict | list: + """Helper function to parse a JSON string if needed. + + Args: + value: The value to parse. Can be a string, dict, or list. + + Returns: + The parsed value. + """ + if isinstance(value, str): + return json.loads(value) + return value + + def _get_output_type( + self, + task: Literal["classification", "regression"], + output_type: str | None + ) -> str: + """Get the prediction output type. + + Args: + output_type: The output type to get. + + Returns: + The prediction output type. + """ + if output_type is not None: + supported_output_types = self._CLASSIFICATION_OUTPUT_TYPES | self._REGRESSION_OUTPUT_TYPES + if output_type not in supported_output_types: + raise ValueError(f"Unknown output_type: {output_type!r}. Must be one of {supported_output_types}") + return output_type + + # Fallback to the defaults + return "preds" if task == "classification" else "mean" + + def _init_estimator( + self, + task: Literal["classification", "regression"], + config: dict + ) -> TabPFNClassifier | TabPFNRegressor: + """Initialize a TabPFN estimator. + + Args: + task: The task to initialize the estimator for. + config: The configuration for the estimator. + + Returns: + The initialized estimator. 
+ """ + Estimator = TabPFNClassifier if task == "classification" else TabPFNRegressor + + # Validate provided config keys against the estimator constructor + sig = inspect.signature(Estimator.__init__) + + constructor_params = set(sig.parameters.keys()) - {"self"} + supplied_keys = set(config.keys()) + invalid_keys = supplied_keys - constructor_params + + # Raise an error if any invalid keys are provided + if invalid_keys: + msg = ( + f"Config contains invalid parameters for {Estimator.__name__}: {sorted(invalid_keys)}.\n" + f"Allowed parameters: {sorted(constructor_params)}" + ) + raise ValueError(msg) + + return Estimator(**config) + + def predict(self, model_input, params=None): + """Run predictions. + + TabPFN runs predictions in a single forward pass. The model is + fitted on the training data and then used to predict on the test data. + + Args: + model_input: The input data to predict on. + params: The parameters for the model. + + Returns: + The predictions. + """ + # Accept both DataFrame (local pyfunc) and list-of-dicts (serving endpoint) + if isinstance(model_input, pd.DataFrame): + model_input = model_input.to_dict(orient="records") + + assert isinstance(model_input, list), "model_input must be a list with a single row" + assert len(model_input) == 1, "model_input must have a single row" + + model_input: dict = model_input[0] + + # Parse the task configuration + task_config = self._maybe_parse_json(model_input["task_config"]) or {} + + task = task_config.get("task") + if task is None: + raise KeyError("Task is required, must be 'classification' or 'regression'") + + tabpfn_config = task_config.get("tabpfn_config") or {} + predict_params = task_config.get("predict_params") or {} + + output_type = self._get_output_type(task, predict_params.get("output_type")) + + # Parse the input data + X_train = np.array(self._maybe_parse_json(model_input["X_train"]), dtype=np.float64) + y_train = np.array(self._maybe_parse_json(model_input["y_train"]), dtype=np.float64) 
+ X_test = np.array(self._maybe_parse_json(model_input["X_test"]), dtype=np.float64) + + # Initialize the estimator + estimator = self._init_estimator(task, tabpfn_config) + + # Fit the estimator + estimator.fit(X_train, y_train) + + # Run predictions + if task == "classification": + if output_type == "probas": + return estimator.predict_proba(X_test).tolist() + return estimator.predict(X_test).tolist() + + predictions = estimator.predict(X_test, output_type=output_type) + return predictions.tolist() +``` + +### Define the Model Signature + +The model signature tells MLflow (and Unity Catalog) the expected input and output shapes. All columns are `DataType.string` to support both raw Python and JSON-serialized inputs from SQL `ai_query()`. + +```python +# All inputs are STRING for dual-format support: +# - JSON strings from SQL ai_query() via to_json() +# - Raw Python objects, such as `pd.DataFrame`, from notebook calls + +input_schema = Schema([ + ColSpec(DataType.string, "task_config"), + ColSpec(DataType.string, "X_train"), + ColSpec(DataType.string, "y_train"), + ColSpec(DataType.string, "X_test"), +]) + +# Output schema is required by Unity Catalog. +# STRING is the best fit - the wrapper returns variable shapes +# (flat list for preds, nested list for probas) and the serving +# layer serializes whatever .tolist() produces into a JSON string. +output_schema = Schema([ColSpec(DataType.string, name="predictions")]) +signature = ModelSignature(inputs=input_schema, outputs=output_schema) +``` + + + The `output_schema` is **required** for Unity Catalog registration. Using `DataType.string` is the right choice here because the wrapper can return either a flat list (for `preds`) or a nested list (for `probas`/regression quantiles), and the serving layer serializes the result to JSON regardless. + + +--- + +## Step 4: Register to Unity Catalog + +Log the model with MLflow and register it under a fully qualified Unity Catalog path (`catalog.schema.tabpfn`). 
The `input_example` uses `json.dumps()` to match the all-string signature - the wrapper deserializes at predict time via `_maybe_parse_json()`. + +After registration, we tag the latest version with a `"champion"` alias. The serving endpoint references this alias, so you can promote future versions without touching the endpoint config. + +```python +# Input example uses json.dumps() to match the all-string signature +# The wrapper's _maybe_parse_json() deserializes these at predict time +input_example = pd.DataFrame([{ + "task_config": json.dumps({ + "task": "classification", + "tabpfn_config": { + "n_estimators": 8, + "softmax_temperature": 0.9, + }, + "predict_params": { + "output_type": "preds", + }, + }), + "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]), + "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]), + "X_test": json.dumps([[2.0, 3.0, 0.0]]), +}]) + +# Fully qualified Unity Catalog path (portable across workspaces) +CATALOG = spark.catalog.currentCatalog() +SCHEMA = spark.catalog.currentDatabase() +REGISTERED_MODEL_NAME = f"{CATALOG}.{SCHEMA}.tabpfn" + +with mlflow.start_run(run_name="tabpfn-registration") as run: + model_info = mlflow.pyfunc.log_model( + name="tabpfn", + python_model=TabPFNWrapper(), + signature=signature, + input_example=input_example, + pip_requirements=["tabpfn", "numpy", "pandas"], + registered_model_name=REGISTERED_MODEL_NAME, + ) + print(f"Model URI: {model_info.model_uri}") + print(f"Run ID: {run.info.run_id}") +``` + +Once logged, tag the latest version as `champion`: + +```python +# Tag the latest version with a "champion" alias for serving +client = mlflow.MlflowClient() +versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'") +latest = max(versions, key=lambda v: int(v.version)) +client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version) +print(f"Set alias 'champion' → version {latest.version} of {REGISTERED_MODEL_NAME}") +``` + + + Using the 
`champion` alias decouples your serving endpoint from version numbers. When you retrain or update TabPFN, simply reassign the alias - the endpoint continues to route requests without any configuration change. + + +--- + +## Step 5: Test Locally + +Before deploying, verify the registered model works end-to-end from the notebook. The same model handles both raw Python objects and JSON strings - just load it and call `.predict()`. + + + + ```python + # Load the registered model from MLflow + loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) + + # DataFrame with raw Python objects, the notebook-friendly example + predictions = loaded_model.predict(pd.DataFrame([{ + "task_config": {"task": "classification"}, + "X_train": [ + [1.0, 2.0, 0.0], + [3.0, 4.0, 1.0], + [5.0, 6.0, 0.0], + [7.0, 8.0, 1.0] + ], + "y_train": [0.0, 1.0, 0.0, 1.0], + "X_test": [ + [2.0, 3.0, 0.0] + ], + }])) + predictions + ``` + + + ```python + loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) + + predictions = loaded_model.predict(pd.DataFrame([{ + "task_config": json.dumps({"task": "classification", "predict_params": {"output_type": "probas"}}), + "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]), + "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]), + "X_test": json.dumps([[2.0, 3.0, 0.0]]), + }])) + predictions + ``` + + + ```python + loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) + + predictions = loaded_model.predict(pd.DataFrame([{ + "task_config": {"task": "regression", "predict_params": {"output_type": "mean"}}, + "X_train": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + "y_train": [1.5, 3.5, 5.5, 7.5], + "X_test": [[2.0, 3.0]], + }])) + predictions + ``` + + + +The wrapper's `_maybe_parse_json()` transparently handles both formats - no code changes required between calling from a notebook and calling from a REST endpoint. 
+ +--- + +## Step 6: End-to-End Example - Lending Club Loan Data + +Let's run a real-world classification task: predicting loan default (Good vs Bad) on the [Lending Club Q2 2018](https://www.kaggle.com/datasets/wordsforthewise/lending-club) dataset. This dataset ships with every Databricks workspace at `/databricks-datasets/`, so you can run this without any additional data download. + +### Load and Prepare the Data + +```python +from sklearn.model_selection import train_test_split + +# Load Lending Club Q2 2018 (ships with every Databricks workspace) +df = ( + spark.read.csv( + "/databricks-datasets/lending-club-loan-stats/LoanStats_2018Q2.csv", + header=True, inferSchema=True, + ) + .select( + "loan_status", + "loan_amnt", "funded_amnt", "installment", "annual_inc", "dti", + "open_acc", "revol_bal", "total_acc", "delinq_2yrs", "inq_last_6mths", + "pub_rec", "mort_acc", "tot_cur_bal", "total_pymnt", "last_pymnt_amnt", + ) + .dropna(subset=["loan_status"]) + .toPandas() +) + +# Binary target: Good (0) vs Bad (1) +df["target"] = (df["loan_status"].apply( + lambda s: 0 if s in ("Current", "Fully Paid") else 1 +).astype(int)) +df = df.drop(columns=["loan_status"]).dropna() + +# Sample and split +df_sample = df.sample(n=6_000, random_state=42) +X_train, X_test, y_train, y_test = train_test_split( + df_sample.drop(columns=["target"]), df_sample["target"], + test_size=0.2, random_state=42, stratify=df_sample["target"], +) + +print(f"Train: {len(X_train):,} × {X_train.shape[1]} (bad-loan rate: {y_train.mean():.2%})") +print(f"Test: {len(X_test):,} × {X_test.shape[1]}") +``` + +### Run Predictions + +Pass the full dataset through the registered MLflow model in a single call. We request `probas` so we can compute ROC-AUC alongside accuracy and F1. 
+ +```python +# Predict via the registered MLflow model +predictions = loaded_model.predict(pd.DataFrame([{ + "task_config": { + "task": "classification", + "predict_params": {"output_type": "probas"}, + }, + "X_train": X_train.values.tolist(), + "y_train": y_train.values.tolist(), + "X_test": X_test.values.tolist(), +}])) + +probas = np.array(predictions) +y_pred = probas.argmax(axis=1) + +print(f"Predicted {len(y_pred):,} samples") +``` + +### Evaluate Results + +```python +from sklearn.metrics import accuracy_score, f1_score, roc_auc_score + +print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}") +print(f"F1 (weighted): {f1_score(y_test, y_pred, average='weighted'):.4f}") +print(f"ROC-AUC: {roc_auc_score(y_test, probas[:, 1]):.4f}") + +# Materialize as a Spark DataFrame +results_sdf = spark.createDataFrame(pd.DataFrame({ + "actual": y_test.values, + "predicted": y_pred, + "probability_bad_loan": np.round(probas[:, 1], 4), +})) + +display(results_sdf) +``` + +No fine-tuning, no feature engineering pipeline, no hyperparameter search - TabPFN fits and predicts in a single forward pass. + +--- + +## Step 7: Deploy to Mosaic AI Model Serving + +Deploy the registered model to a GPU-accelerated serving endpoint. The TabPFN token is securely passed via Databricks Secrets - it is never stored in the endpoint configuration. 
+ +```python +import mlflow.deployments + +# Resolve the "champion" alias to a version number +client = mlflow.MlflowClient() +champion = client.get_model_version_by_alias(REGISTERED_MODEL_NAME, "champion") +print(f"Deploying {REGISTERED_MODEL_NAME} version {champion.version}") + +# Get the deployment MLflow client +client = mlflow.deployments.get_deploy_client("databricks") + +# Create the endpoint, will return immediately and continue initializing the endpoint async +# Check for the status in your Databricks console +endpoint = client.create_endpoint( + name="tabpfn-endpoint", + config={ + "served_entities": [{ + "entity_name": REGISTERED_MODEL_NAME, + "entity_version": str(champion.version), + "workload_size": "Medium", + "workload_type": "GPU_MEDIUM", + "scale_to_zero_enabled": True, + "environment_vars": { + "TABPFN_TOKEN": "{{secrets/tabpfn/tabpfn_token}}", + }, + }], + }, +) + +print(f"Endpoint created: {endpoint['name']}") +``` + + + `create_endpoint` returns immediately and initializes the endpoint asynchronously. Monitor the status in your Databricks console under **Serving** → **tabpfn-endpoint**. 
+ + +### Calling the Endpoint + +Once the endpoint is live, you can reach it from Python, REST, or SQL: + + + + ```python + import mlflow.deployments + + client = mlflow.deployments.get_deploy_client("databricks") + response = client.predict( + endpoint="tabpfn-endpoint", + inputs={"dataframe_records": [{ + "task_config": json.dumps({"task": "classification"}), + "X_train": json.dumps([[1.0, 2.0], [3.0, 4.0]]), + "y_train": json.dumps([0.0, 1.0]), + "X_test": json.dumps([[2.0, 3.0]]), + }]}, + ) + print(response) + ``` + + + ```bash + curl -X POST \ + https:///serving-endpoints/tabpfn-endpoint/invocations \ + -H "Authorization: Bearer $DATABRICKS_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "dataframe_records": [{ + "task_config": "{\"task\": \"classification\"}", + "X_train": "[[1.0, 2.0], [3.0, 4.0]]", + "y_train": "[0.0, 1.0]", + "X_test": "[[2.0, 3.0]]" + }] + }' + ``` + + + ```sql + SELECT ai_query( + 'tabpfn-endpoint', + named_struct( + 'task_config', to_json(named_struct('task', 'classification')), + 'X_train', to_json(array(array(1.0, 2.0), array(3.0, 4.0))), + 'y_train', to_json(array(0.0, 1.0)), + 'X_test', to_json(array(array(2.0, 3.0))) + ) + ) AS prediction + ``` + + + +--- + +## How the Input Format Works + +Understanding `task_config` is key to using the endpoint effectively. It controls both what TabPFN does and how it does it. + +| Field | Type | Description | +|---|---|---| +| `task` | `string` | Required. `"classification"` or `"regression"` | +| `tabpfn_config` | `object` | Optional. Passed directly to `TabPFNClassifier` / `TabPFNRegressor` constructor (e.g. `n_estimators`, `softmax_temperature`) | +| `predict_params` | `object` | Optional. 
Controls output format via `output_type` | + +### Output Types + + + + | `output_type` | Description | + |---|---| + | `preds` (default) | Predicted class labels | + | `probas` | Class probabilities (nested list) | + + + | `output_type` | Description | + |---|---| + | `mean` (default) | Mean prediction | + | `mode` | Mode of predictive distribution | + | `median` | Median prediction | + | `quantiles` | Quantile predictions | + | `main` | Main summary statistics | + | `full` | Full predictive distribution | + + + +--- + +## Promoting a New Model Version + +When you want to update the model (e.g. after a new TabPFN release), simply re-run the registration cell and reassign the alias. The endpoint keeps routing to `champion` with no configuration change needed. + +```python +# After re-running the registration cell with a new TabPFNWrapper... +client = mlflow.MlflowClient() +versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'") +latest = max(versions, key=lambda v: int(v.version)) +client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version) +print(f"Promoted version {latest.version} to 'champion'") +``` + +--- + +## Next Steps + + + + Explore all supported parameters, estimator options, and output types. + + + Learn how TabPFN accelerates business transformation on Databricks. + + + Understand the `mlflow.pyfunc.PythonModel` interface used by the wrapper. + + + Configure autoscaling, traffic splitting, and monitoring for your endpoint. 
+ + diff --git a/docs.json index 7b3c6d7..1e53f52 100644 --- a/docs.json +++ b/docs.json @@ -148,6 +148,14 @@ "integrations/sagemaker" ] }, + { + + "group": "Cookbooks", + "icon": "book", + "pages": [ + "cookbooks/mlflow" + ] + }, { "group": "Use Cases", "pages": [ diff --git a/tutorials/mlflow.mdx b/tutorials/mlflow.mdx new file mode 100644 index 0000000..cd532d8 --- /dev/null +++ b/tutorials/mlflow.mdx @@ -0,0 +1,641 @@ +--- +title: "TabPFN with MLflow" +description: "Learn how to wrap TabPFN as an MLflow PythonModel, register it to Unity Catalog, and deploy it to a Mosaic AI Model Serving endpoint." +--- + +## Overview + +[TabPFN](https://github.com/priorlabs/tabpfn) is a tabular foundation model that makes accurate predictions in a single forward pass - no dataset-specific training, no hyperparameter search, no separate model artifact per target variable. + +The promise: **one registered model** that handles both classification and regression across any tabular dataset. + + + + Manage one model instead of many. No per-dataset training runs, no hyperparameter searches, no separate artifacts per outcome. + + + Use a single model for classification and regression across any tabular dataset - from a notebook or SQL `ai_query()`. + + + +By the end of this tutorial you will have: + +- Wrapped TabPFN as an MLflow `PythonModel` +- Registered it to Unity Catalog with a `champion` alias +- Tested it locally on classification and regression tasks +- Run an end-to-end example on the Lending Club loan dataset +- Deployed it to a GPU-accelerated Mosaic AI Model Serving endpoint + + + For the full list of supported TabPFN parameters (estimators, output types, configuration options) see the [TabPFN GitHub repository](https://github.com/priorlabs/tabpfn). For broader context on TabPFN and Databricks, read the [Databricks blog post](https://www.databricks.com/blog/tabpfn-ai-accelerates-business-transformation-databricks). 
+ + +--- + +## Prerequisites + +- A Databricks workspace with Unity Catalog enabled +- Access to a **serverless GPU environment** with at least an **NVIDIA A10G** GPU +- A TabPFN token from [Prior Labs](https://github.com/priorlabs/tabpfn?tab=readme-ov-file#installation--setup) + + + TabPFN runs inference via PyTorch, so a GPU significantly speeds up predictions - especially on larger datasets. In the notebook toolbar, select **Serverless** with a GPU-enabled instance (A10G or better). Serverless handles provisioning and scaling automatically. + + +--- + +## Step 1: Install Dependencies + +Run the following in your Databricks notebook cell: + +```python +%pip install mlflow tabpfn +dbutils.library.restartPython() +``` + +--- + +## Step 2: Configure Your TabPFN Token + +TabPFN model weights are gated and require authentication. Store your token as a Databricks secret so it never lives in plain text in your notebook. + +Run these commands in the **Databricks CLI**: + +```bash +databricks secrets create-scope tabpfn +databricks secrets put-secret tabpfn tabpfn_token --string-value "" +``` + +Then reference the secret in your notebook: + +```python +import os + +os.environ["TABPFN_TOKEN"] = dbutils.secrets.get(scope="tabpfn", key="tabpfn_token") +``` + +--- + +## Step 3: Define the Wrapper and Signature + +The `TabPFNWrapper` is a custom `mlflow.pyfunc.PythonModel`. It is the heart of this integration and handles three concerns: + +- **Dual-format input** - accepts both raw Python objects (from notebooks) and JSON strings (from SQL `ai_query()`) +- **Task routing** - classification or regression, controlled by `task_config` +- **Flexible output** - class labels, probabilities, or regression predictions based on `output_type` + +All input columns use `DataType.string` so the same endpoint works from Python, REST, and SQL without any changes to the registered model. The `_maybe_parse_json()` helper transparently handles both formats at predict time. 
+ +```python +import os +import json +import inspect + +import mlflow +import numpy as np +import pandas as pd + +from typing import Literal + +from mlflow.models.signature import ModelSignature +from mlflow.types.schema import ( + Array, + ColSpec, + DataType, + Object, + ParamSchema, + ParamSpec, + Property, + Schema +) +from tabpfn import TabPFNClassifier, TabPFNRegressor + + +class TabPFNWrapper(mlflow.pyfunc.PythonModel): + """MLflow PythonModel wrapper for TabPFN""" + + _CLASSIFICATION_OUTPUT_TYPES = {"preds", "probas"} + """The model output types for classification.""" + + _REGRESSION_OUTPUT_TYPES = {"mean", "mode", "median", "quantiles", "main", "full"} + """The model output types for regression.""" + + @staticmethod + def _maybe_parse_json(value: str | dict | list) -> dict | list: + """Helper function to parse a JSON string if needed. + + Args: + value: The value to parse. Can be a string, dict, or list. + + Returns: + The parsed value. + """ + if isinstance(value, str): + return json.loads(value) + return value + + def _get_output_type( + self, + task: Literal["classification", "regression"], + output_type: str | None + ) -> str: + """Get the prediction output type. + + Args: + output_type: The output type to get. + + Returns: + The prediction output type. + """ + if output_type is not None: + supported_output_types = self._CLASSIFICATION_OUTPUT_TYPES | self._REGRESSION_OUTPUT_TYPES + if output_type not in supported_output_types: + raise ValueError(f"Unknown output_type: {output_type!r}. Must be one of {supported_output_types}") + return output_type + + # Fallback to the defaults + return "preds" if task == "classification" else "mean" + + def _init_estimator( + self, + task: Literal["classification", "regression"], + config: dict + ) -> TabPFNClassifier | TabPFNRegressor: + """Initialize a TabPFN estimator. + + Args: + task: The task to initialize the estimator for. + config: The configuration for the estimator. + + Returns: + The initialized estimator. 
+ """ + Estimator = TabPFNClassifier if task == "classification" else TabPFNRegressor + + # Validate provided config keys against the estimator constructor + sig = inspect.signature(Estimator.__init__) + + constructor_params = set(sig.parameters.keys()) - {"self"} + supplied_keys = set(config.keys()) + invalid_keys = supplied_keys - constructor_params + + # Raise an error if any invalid keys are provided + if invalid_keys: + msg = ( + f"Config contains invalid parameters for {Estimator.__name__}: {sorted(invalid_keys)}.\n" + f"Allowed parameters: {sorted(constructor_params)}" + ) + raise ValueError(msg) + + return Estimator(**config) + + def predict(self, model_input, params=None): + """Run predictions. + + TabPFN runs predictions in a single forward pass. The model is + fitted on the training data and then used to predict on the test data. + + Args: + model_input: The input data to predict on. + params: The parameters for the model. + + Returns: + The predictions. + """ + # Accept both DataFrame (local pyfunc) and list-of-dicts (serving endpoint) + if isinstance(model_input, pd.DataFrame): + model_input = model_input.to_dict(orient="records") + + assert isinstance(model_input, list), "model_input must be a list with a single row" + assert len(model_input) == 1, "model_input must have a single row" + + model_input: dict = model_input[0] + + # Parse the task configuration + task_config = self._maybe_parse_json(model_input["task_config"]) or {} + + task = task_config.get("task") + if task is None: + raise KeyError("Task is required, must be 'classification' or 'regression'") + + tabpfn_config = task_config.get("tabpfn_config") or {} + predict_params = task_config.get("predict_params") or {} + + output_type = self._get_output_type(task, predict_params.get("output_type")) + + # Parse the input data + X_train = np.array(self._maybe_parse_json(model_input["X_train"]), dtype=np.float64) + y_train = np.array(self._maybe_parse_json(model_input["y_train"]), dtype=np.float64) 
+ X_test = np.array(self._maybe_parse_json(model_input["X_test"]), dtype=np.float64) + + # Initialize the estimator + estimator = self._init_estimator(task, tabpfn_config) + + # Fit the estimator + estimator.fit(X_train, y_train) + + # Run predictions + if task == "classification": + if output_type == "probas": + return estimator.predict_proba(X_test).tolist() + return estimator.predict(X_test).tolist() + + predictions = estimator.predict(X_test, output_type=output_type) + return predictions.tolist() +``` + +### Define the Model Signature + +The model signature tells MLflow (and Unity Catalog) the expected input and output shapes. All columns are `DataType.string` to support both raw Python and JSON-serialized inputs from SQL `ai_query()`. + +```python +# All inputs are STRING for dual-format support: +# - JSON strings from SQL ai_query() via to_json() +# - Raw Python objects, such as `pd.DataFrame`, from notebook calls + +input_schema = Schema([ + ColSpec(DataType.string, "task_config"), + ColSpec(DataType.string, "X_train"), + ColSpec(DataType.string, "y_train"), + ColSpec(DataType.string, "X_test"), +]) + +# Output schema is required by Unity Catalog. +# STRING is the best fit - the wrapper returns variable shapes +# (flat list for preds, nested list for probas) and the serving +# layer serializes whatever .tolist() produces into a JSON string. +output_schema = Schema([ColSpec(DataType.string, name="predictions")]) +signature = ModelSignature(inputs=input_schema, outputs=output_schema) +``` + + + The `output_schema` is **required** for Unity Catalog registration. Using `DataType.string` is the right choice here because the wrapper can return either a flat list (for `preds`) or a nested list (for `probas`/regression quantiles), and the serving layer serializes the result to JSON regardless. + + +--- + +## Step 4: Register to Unity Catalog + +Log the model with MLflow and register it under a fully qualified Unity Catalog path (`catalog.schema.tabpfn`). 
The `input_example` uses `json.dumps()` to match the all-string signature - the wrapper deserializes at predict time via `_maybe_parse_json()`. + +After registration, we tag the latest version with a `"champion"` alias. The serving endpoint references this alias, so you can promote future versions without touching the endpoint config. + +```python +# Input example uses json.dumps() to match the all-string signature +# The wrapper's _maybe_parse_json() deserializes these at predict time +input_example = pd.DataFrame([{ + "task_config": json.dumps({ + "task": "classification", + "tabpfn_config": { + "n_estimators": 8, + "softmax_temperature": 0.9, + }, + "predict_params": { + "output_type": "preds", + }, + }), + "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]), + "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]), + "X_test": json.dumps([[2.0, 3.0, 0.0]]), +}]) + +# Fully qualified Unity Catalog path (portable across workspaces) +CATALOG = spark.catalog.currentCatalog() +SCHEMA = spark.catalog.currentDatabase() +REGISTERED_MODEL_NAME = f"{CATALOG}.{SCHEMA}.tabpfn" + +with mlflow.start_run(run_name="tabpfn-registration") as run: + model_info = mlflow.pyfunc.log_model( + name="tabpfn", + python_model=TabPFNWrapper(), + signature=signature, + input_example=input_example, + pip_requirements=["tabpfn", "numpy", "pandas"], + registered_model_name=REGISTERED_MODEL_NAME, + ) + print(f"Model URI: {model_info.model_uri}") + print(f"Run ID: {run.info.run_id}") +``` + +Once logged, tag the latest version as `champion`: + +```python +# Tag the latest version with a "champion" alias for serving +client = mlflow.MlflowClient() +versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'") +latest = max(versions, key=lambda v: int(v.version)) +client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version) +print(f"Set alias 'champion' → version {latest.version} of {REGISTERED_MODEL_NAME}") +``` + + + Using the 
`champion` alias decouples your serving endpoint from version numbers. When you retrain or update TabPFN, simply reassign the alias - the endpoint continues to route requests without any configuration change. + + +--- + +## Step 5: Test Locally + +Before deploying, verify the registered model works end-to-end from the notebook. The same model handles both raw Python objects and JSON strings - just load it and call `.predict()`. + + + + ```python + # Load the registered model from MLflow + loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) + + # DataFrame with raw Python objects, the notebook-friendly example + predictions = loaded_model.predict(pd.DataFrame([{ + "task_config": {"task": "classification"}, + "X_train": [ + [1.0, 2.0, 0.0], + [3.0, 4.0, 1.0], + [5.0, 6.0, 0.0], + [7.0, 8.0, 1.0] + ], + "y_train": [0.0, 1.0, 0.0, 1.0], + "X_test": [ + [2.0, 3.0, 0.0] + ], + }])) + predictions + ``` + + + ```python + loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) + + predictions = loaded_model.predict(pd.DataFrame([{ + "task_config": json.dumps({"task": "classification", "predict_params": {"output_type": "probas"}}), + "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]), + "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]), + "X_test": json.dumps([[2.0, 3.0, 0.0]]), + }])) + predictions + ``` + + + ```python + loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) + + predictions = loaded_model.predict(pd.DataFrame([{ + "task_config": {"task": "regression", "predict_params": {"output_type": "mean"}}, + "X_train": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + "y_train": [1.5, 3.5, 5.5, 7.5], + "X_test": [[2.0, 3.0]], + }])) + predictions + ``` + + + +The wrapper's `_maybe_parse_json()` transparently handles both formats - no code changes required between calling from a notebook and calling from a REST endpoint. 
+ +--- + +## Step 6: End-to-End Example - Lending Club Loan Data + +Let's run a real-world classification task: predicting loan default (Good vs Bad) on the [Lending Club Q2 2018](https://www.kaggle.com/datasets/wordsforthewise/lending-club) dataset. This dataset ships with every Databricks workspace at `/databricks-datasets/`, so you can run this without any additional data download. + +### Load and Prepare the Data + +```python +from sklearn.model_selection import train_test_split + +# Load Lending Club Q2 2018 (ships with every Databricks workspace) +df = ( + spark.read.csv( + "/databricks-datasets/lending-club-loan-stats/LoanStats_2018Q2.csv", + header=True, inferSchema=True, + ) + .select( + "loan_status", + "loan_amnt", "funded_amnt", "installment", "annual_inc", "dti", + "open_acc", "revol_bal", "total_acc", "delinq_2yrs", "inq_last_6mths", + "pub_rec", "mort_acc", "tot_cur_bal", "total_pymnt", "last_pymnt_amnt", + ) + .dropna(subset=["loan_status"]) + .toPandas() +) + +# Binary target: Good (0) vs Bad (1) +df["target"] = (df["loan_status"].apply( + lambda s: 0 if s in ("Current", "Fully Paid") else 1 +).astype(int)) +df = df.drop(columns=["loan_status"]).dropna() + +# Sample and split +df_sample = df.sample(n=6_000, random_state=42) +X_train, X_test, y_train, y_test = train_test_split( + df_sample.drop(columns=["target"]), df_sample["target"], + test_size=0.2, random_state=42, stratify=df_sample["target"], +) + +print(f"Train: {len(X_train):,} × {X_train.shape[1]} (bad-loan rate: {y_train.mean():.2%})") +print(f"Test: {len(X_test):,} × {X_test.shape[1]}") +``` + +### Run Predictions + +Pass the full dataset through the registered MLflow model in a single call. We request `probas` so we can compute ROC-AUC alongside accuracy and F1. 
+ +```python +# Predict via the registered MLflow model +predictions = loaded_model.predict(pd.DataFrame([{ + "task_config": { + "task": "classification", + "predict_params": {"output_type": "probas"}, + }, + "X_train": X_train.values.tolist(), + "y_train": y_train.values.tolist(), + "X_test": X_test.values.tolist(), +}])) + +probas = np.array(predictions) +y_pred = probas.argmax(axis=1) + +print(f"Predicted {len(y_pred):,} samples") +``` + +### Evaluate Results + +```python +from sklearn.metrics import accuracy_score, f1_score, roc_auc_score + +print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}") +print(f"F1 (weighted): {f1_score(y_test, y_pred, average='weighted'):.4f}") +print(f"ROC-AUC: {roc_auc_score(y_test, probas[:, 1]):.4f}") + +# Materialize as a Spark DataFrame +results_sdf = spark.createDataFrame(pd.DataFrame({ + "actual": y_test.values, + "predicted": y_pred, + "probability_bad_loan": np.round(probas[:, 1], 4), +})) + +display(results_sdf) +``` + +No fine-tuning, no feature engineering pipeline, no hyperparameter search - TabPFN fits and predicts in a single forward pass. + +--- + +## Step 7: Deploy to Mosaic AI Model Serving + +Deploy the registered model to a GPU-accelerated serving endpoint. The TabPFN token is securely passed via Databricks Secrets - it is never stored in the endpoint configuration. 
+ +```python +import mlflow.deployments + +# Resolve the "champion" alias to a version number +client = mlflow.MlflowClient() +champion = client.get_model_version_by_alias(REGISTERED_MODEL_NAME, "champion") +print(f"Deploying {REGISTERED_MODEL_NAME} version {champion.version}") + +# Get the deployment MLflow client +client = mlflow.deployments.get_deploy_client("databricks") + +# Create the endpoint, will return immediately and continue initializing the endpoint async +# Check for the status in your Databricks console +endpoint = client.create_endpoint( + name="tabpfn-endpoint", + config={ + "served_entities": [{ + "entity_name": REGISTERED_MODEL_NAME, + "entity_version": str(champion.version), + "workload_size": "Medium", + "workload_type": "GPU_MEDIUM", + "scale_to_zero_enabled": True, + "environment_vars": { + "HF_TOKEN": "{{secrets/tabpfn/hf_token}}", + }, + }], + }, +) + +print(f"Endpoint created: {endpoint['name']}") +``` + + + `create_endpoint` returns immediately and initializes the endpoint asynchronously. Monitor the status in your Databricks console under **Serving** → **tabpfn-endpoint**. 
+ + +### Calling the Endpoint + +Once the endpoint is live, you can reach it from Python, REST, or SQL: + + + + ```python + import mlflow.deployments + + client = mlflow.deployments.get_deploy_client("databricks") + response = client.predict( + endpoint="tabpfn-endpoint", + inputs={"dataframe_records": [{ + "task_config": json.dumps({"task": "classification"}), + "X_train": json.dumps([[1.0, 2.0], [3.0, 4.0]]), + "y_train": json.dumps([0.0, 1.0]), + "X_test": json.dumps([[2.0, 3.0]]), + }]}, + ) + print(response) + ``` + + + ```bash + curl -X POST \ + https:///serving-endpoints/tabpfn-endpoint/invocations \ + -H "Authorization: Bearer $DATABRICKS_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "dataframe_records": [{ + "task_config": "{\"task\": \"classification\"}", + "X_train": "[[1.0, 2.0], [3.0, 4.0]]", + "y_train": "[0.0, 1.0]", + "X_test": "[[2.0, 3.0]]" + }] + }' + ``` + + + ```sql + SELECT ai_query( + 'tabpfn-endpoint', + named_struct( + 'task_config', to_json(named_struct('task', 'classification')), + 'X_train', to_json(array(array(1.0, 2.0), array(3.0, 4.0))), + 'y_train', to_json(array(0.0, 1.0)), + 'X_test', to_json(array(array(2.0, 3.0))) + ) + ) AS prediction + ``` + + + +--- + +## How the Input Format Works + +Understanding `task_config` is key to using the endpoint effectively. It controls both what TabPFN does and how it does it. + +| Field | Type | Description | +|---|---|---| +| `task` | `string` | Required. `"classification"` or `"regression"` | +| `tabpfn_config` | `object` | Optional. Passed directly to `TabPFNClassifier` / `TabPFNRegressor` constructor (e.g. `n_estimators`, `softmax_temperature`) | +| `predict_params` | `object` | Optional. 
Controls output format via `output_type` | + +### Output Types + + + + | `output_type` | Description | + |---|---| + | `preds` (default) | Predicted class labels | + | `probas` | Class probabilities (nested list) | + + + | `output_type` | Description | + |---|---| + | `mean` (default) | Mean prediction | + | `mode` | Mode of predictive distribution | + | `median` | Median prediction | + | `quantiles` | Quantile predictions | + | `main` | Main summary statistics | + | `full` | Full predictive distribution | + + + +--- + +## Promoting a New Model Version + +When you want to update the model (e.g. after a new TabPFN release), simply re-run the registration cell and reassign the alias. The endpoint keeps routing to `champion` with no configuration change needed. + +```python +# After re-running the registration cell with a new TabPFNWrapper... +client = mlflow.MlflowClient() +versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'") +latest = max(versions, key=lambda v: int(v.version)) +client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version) +print(f"Promoted version {latest.version} to 'champion'") +``` + +--- + +## Next Steps + + + + Explore all supported parameters, estimator options, and output types. + + + Learn how TabPFN accelerates business transformation on Databricks. + + + Understand the `mlflow.pyfunc.PythonModel` interface used by the wrapper. + + + Configure autoscaling, traffic splitting, and monitoring for your endpoint. + +