Learning how to create models with yml files#

This notebook shows how to build pymc-marketing models from YAML (.yml) configuration files, so you can recreate the same model in a production environment without rewriting the model-definition code.

Setup#

import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

from pymc_marketing.mmm.builders.yaml import build_mmm_from_yaml
from pymc_marketing.paths import data_dir

warnings.filterwarnings("ignore")

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
/Users/carlostrujillo/Documents/GitHub/pymc-marketing/pymc_marketing/mmm/multidimensional.py:216: FutureWarning: This functionality is experimental and subject to change. If you encounter any issues or have suggestions, please raise them at: https://github.com/pymc-labs/pymc-marketing/issues/new
  warnings.warn(warning_msg, FutureWarning, stacklevel=1)
/var/folders/f0/rbz8xs8s17n3k3f_ccp31bvh0000gn/T/ipykernel_49626/1583561548.py:7: UserWarning: The pymc_marketing.mmm.builders module is experimental and its API may change without warning.
  from pymc_marketing.mmm.builders.yaml import build_mmm_from_yaml
X = pd.read_csv(data_dir / "processed" / "X.csv")
y = pd.read_csv(data_dir / "processed" / "y.csv")
X.head(3)
         date market  channel_1  channel_2
0  2023-01-01     US  70.171496  20.945956
1  2023-01-02     US  90.243918  45.828916
2  2023-01-03     US   9.178717  26.322735
y.head(3)
           y
0  45.453806
1  42.516346
2  54.250939

Multidimensional model#

mmm = build_mmm_from_yaml(
    X=X, y=y, config_path=data_dir / "config_files" / "multi_dimensional_model.yml"
)
mmm.model.to_graphviz()
[Figure: model graph rendered by mmm.model.to_graphviz()]
prior_predictive = mmm.sample_prior_predictive(X=X, y=y, samples=1_000)
Sampling: [adstock_alpha, intercept_contribution, saturation_alpha, saturation_lam, y, y_sigma]
prior_predictive
<xarray.Dataset> Size: 2MB
Dimensions:  (date: 100, market: 2, sample: 1000)
Coordinates:
  * date     (date) datetime64[ns] 800B 2023-01-01 2023-01-02 ... 2023-04-10
  * market   (market) <U2 16B 'EU' 'US'
  * sample   (sample) object 8kB MultiIndex
  * chain    (sample) int64 8kB 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0
  * draw     (sample) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    y        (date, market, sample) float64 2MB -0.8231 0.1999 ... -1.278 0.6622
Attributes:
    created_at:                 2025-10-07T09:24:53.031443+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1
    pymc_marketing_version:     0.16.0
mmm.fit(
    X=X,
    y=y.y,
    random_seed=42,
)

mmm.sample_posterior_predictive(
    X=X,
    extend_idata=True,
    combined=True,
    random_seed=42,
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 4 jobs)
NUTS: [intercept_contribution, adstock_alpha, saturation_alpha, saturation_lam, y_sigma]

Sampling 8 chains for 1_000 tune and 200 draw iterations (8_000 + 1_600 draws total) took 14 seconds.
There were 3 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details

Sampling: [y]

<xarray.Dataset> Size: 3MB
Dimensions:  (date: 100, market: 2, sample: 1600)
Coordinates:
  * date     (date) datetime64[ns] 800B 2023-01-01 2023-01-02 ... 2023-04-10
  * market   (market) <U2 16B 'EU' 'US'
  * sample   (sample) object 13kB MultiIndex
  * chain    (sample) int64 13kB 0 0 0 0 0 0 0 0 0 0 0 ... 7 7 7 7 7 7 7 7 7 7 7
  * draw     (sample) int64 13kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199
Data variables:
    y        (date, market, sample) float64 3MB 0.5852 0.5207 ... 0.595 0.5213
Attributes:
    created_at:                 2025-10-07T09:25:19.061570+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1

How the config works#

# Let's look at the content of the basic model configuration file
with open(data_dir / "config_files" / "basic_model.yml") as f:
    basic_config = f.read()

print(basic_config)
model:
  class: pymc_marketing.mmm.multidimensional.MMM
  kwargs:
    date_column: "date"
    channel_columns:                                     # explicit for reproducibility
      - channel_1
      - channel_2
      # …
    target_column: "y"

    # --- media transformations ---------------------------------------
    adstock:
      class: pymc_marketing.mmm.GeometricAdstock
      kwargs: {l_max: 12}        # any other hyper-parameters here

    saturation:
      class: pymc_marketing.mmm.MichaelisMentenSaturation
      kwargs: {}                 # default α, λ priors inside the class

# ----------------------------------------------------------------------
# (optional) sampler options you plan to forward to pm.sample():
    sampler_config:
      tune: 1000
      draws: 200
      chains: 8
      random_seed: 42
      target_accept: 0.90
      # nuts_sampler: "nutpie"

# ----------------------------------------------------------------------
# (optional) idata from a previous sample
# idata_path: "data/idata.nc"

# ----------------------------------------------------------------------
# (optional) Data paths
# data:
#   X_path: "data/X.csv"
#   y_path: "data/y.csv"

The configuration file uses a structured YAML format with several key sections:

  1. schema_version: Version identifier for the configuration schema (basic_model.yml, printed above, does not set it)

  2. model: The main model configuration

    • class: The Python class to instantiate (fully qualified name)

    • kwargs: Arguments passed to the model constructor

      • Including data columns, transformations (adstock, saturation)

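  3. sample_kwargs: Optional parameters for the sampling process (basic_model.yml, printed above, instead sets its sampler options under sampler_config inside model.kwargs)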

  4. data: Optional paths to data files
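
To make the nesting concrete, here is a minimal sketch that walks a parsed configuration and lists every entry carrying a class key. The dictionary is a hand-written stand-in for what a YAML parser would return for a file shaped like basic_model.yml (values copied from the printout above), and find_class_specs is a hypothetical helper written for this illustration, not part of pymc-marketing.

```python
# Hand-written stand-in for the parsed contents of a file shaped like
# basic_model.yml; only the keys shown in the printout above are included.
config = {
    "model": {
        "class": "pymc_marketing.mmm.multidimensional.MMM",
        "kwargs": {
            "date_column": "date",
            "channel_columns": ["channel_1", "channel_2"],
            "target_column": "y",
            "adstock": {
                "class": "pymc_marketing.mmm.GeometricAdstock",
                "kwargs": {"l_max": 12},
            },
            "saturation": {
                "class": "pymc_marketing.mmm.MichaelisMentenSaturation",
                "kwargs": {},
            },
        },
    },
}


def find_class_specs(node, path="model"):
    """Hypothetical helper: map each config path that has a 'class' key to its class name."""
    found = {}
    if isinstance(node, dict):
        if "class" in node:
            found[path] = node["class"]
        # Descend only into kwargs, where nested class/kwargs entries live.
        for key, value in node.get("kwargs", {}).items():
            found.update(find_class_specs(value, f"{path}.kwargs.{key}"))
    return found


specs = find_class_specs(config["model"])
for path, class_name in specs.items():
    print(f"{path}: {class_name}")
```

Each path it prints corresponds to one object that has to be instantiated: the model itself plus the adstock and saturation transformations passed to its constructor.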

The build_mmm_from_yaml function:

  • Parses this YAML configuration

  • Uses the ‘build’ function to instantiate objects recursively

  • Handles special cases like priors and distributions

  • Returns a fully configured MMM model ready for sampling

  • If idata_path is provided, loads the inference data (idata) saved from a previous sampling run and exposes it through the model's idata property.
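
The recursive instantiation step can be sketched roughly as follows. This is an illustration of the general class/kwargs pattern under stated assumptions, not the library's actual build function: build_object is a hypothetical name, and standard-library classes (datetime.timezone and datetime.timedelta) stand in for the model and its transformations so the example runs without pymc-marketing installed.

```python
import importlib
from datetime import timedelta


def build_object(spec):
    """Hypothetical sketch: turn {'class': 'pkg.mod.Name', 'kwargs': {...}} into an instance."""
    if isinstance(spec, dict) and "class" in spec:
        # Split "package.module.ClassName" into module path and class name.
        module_name, _, class_name = spec["class"].rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        # Build nested specs first so the constructor receives real objects.
        kwargs = {k: build_object(v) for k, v in spec.get("kwargs", {}).items()}
        return cls(**kwargs)
    if isinstance(spec, dict):
        return {k: build_object(v) for k, v in spec.items()}
    if isinstance(spec, list):
        return [build_object(v) for v in spec]
    return spec  # scalars (str, int, float, None) pass through unchanged


# The outer object takes an inner object as a constructor argument,
# just as the model config nests adstock and saturation specs.
spec = {
    "class": "datetime.timezone",
    "kwargs": {
        "offset": {"class": "datetime.timedelta", "kwargs": {"hours": 2}},
        "name": "CEST",
    },
}

tz = build_object(spec)
print(tz.tzname(None), tz.utcoffset(None) == timedelta(hours=2))
```

Building the inner specs before calling the outer constructor is what lets a single YAML document describe a whole object tree. The special cases mentioned above (priors, distributions) would need extra branches that this sketch leaves out.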

Basic model#

mmm2 = build_mmm_from_yaml(
    X=X, y=y, config_path=data_dir / "config_files" / "basic_model.yml"
)
mmm2.model.to_graphviz()
[Figure: model graph rendered by mmm2.model.to_graphviz()]
prior_predictive = mmm2.sample_prior_predictive(X=X, y=y, samples=1_000)
Sampling: [adstock_alpha, intercept_contribution, saturation_alpha, saturation_lam, y, y_sigma]

Multidimensional Hierarchical Model#

mmm3 = build_mmm_from_yaml(
    X=X,
    y=y,
    config_path=data_dir / "config_files" / "multi_dimensional_hierarchical_model.yml",
)
mmm3.model.to_graphviz()
[Figure: model graph rendered by mmm3.model.to_graphviz()]
prior_predictive = mmm3.sample_prior_predictive(X=X, y=y, samples=1_000)
Sampling: [adstock_alpha, intercept_contribution, intercept_contribution_beta, saturation_beta, saturation_lam, saturation_lam_alpha, y, y_sigma]

Multidimensional Hierarchical with arbitrary effects and calibration#

data_dir / "config_files" / "multi_dimensional_hierarchical_with_arbitrary_effects_model.yml"
PosixPath('/Users/carlostrujillo/Documents/GitHub/pymc-marketing/data/config_files/multi_dimensional_hierarchical_with_arbitrary_effects_model.yml')
mmm4 = build_mmm_from_yaml(
    X=X,
    y=y,
    config_path=data_dir
    / "config_files"
    / "multi_dimensional_hierarchical_with_arbitrary_effects_model.yml",
)
mmm4.model.to_graphviz()
[Figure: model graph rendered by mmm4.model.to_graphviz()]
prior_predictive = mmm4.sample_prior_predictive(X=X, y=y, samples=1_000)
Sampling: [adstock_alpha, delta, delta_mu, example_lift_tests, intercept_contribution, intercept_contribution_beta, saturation_beta, saturation_lam, saturation_lam_alpha, weekly_fourier_beta, weekly_fourier_beta_mu, y, y_sigma]
%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor
Last updated: Tue Oct 07 2025

Python implementation: CPython
Python version       : 3.12.11
IPython version      : 9.4.0

pymc_marketing: 0.16.0
pytensor      : 2.31.7

pymc_marketing: 0.16.0
matplotlib    : 3.10.3
arviz         : 0.22.0
pandas        : 2.3.1

Watermark: 2.5.0