Wrappers
tempestpy ships with two Gymnasium wrappers that connect a compiled shield to a reinforcement-learning training loop.
Both wrappers require gymnasium to be installed:
```bash
pip install gymnasium
```
Concepts: pre-shielding vs post-shielding
| | Pre-shielding | Post-shielding |
|---|---|---|
| When | Before the agent acts: the policy only sees allowed actions | After the agent acts: the policy may propose any action |
| Mechanism | Injects a boolean `action_mask` into `info` | Silently replaces blocked actions before `env.step` |
| Good for | Constrained-policy training (e.g. SB3 `MaskablePPO`) | Wrapping a policy that cannot consume masks |
| Wrapper | `PreShieldWrapper` | `PostShieldWrapper` |
obs_to_values
Both wrappers require an obs_to_values callable that maps a Gymnasium observation (and optionally info) to a dict of PRISM variable values.
This is the bridge between the RL observation space and the PRISM state space.
```python
# Example: observation is a flat numpy array [x, y, has_key]
def obs_to_values(obs, info=None):
    return {"x": int(obs[0]), "y": int(obs[1]), "has_key": bool(obs[2])}
```
The returned dict must match the variable names in the PRISM model exactly.
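Because a misspelled or missing variable name is easy to overlook, it can help to check the returned dict against the model's variable names before training starts. The sketch below is plain Python; `PRISM_VARS` and `check_values` are hypothetical helpers, and the variable set is written out by hand from the `x`, `y`, `has_key` example above rather than read from tempestpy.

```python
# Hypothetical: the PRISM variables of the example model, listed by hand.
PRISM_VARS = {"x", "y", "has_key"}


def obs_to_values(obs, info=None):
    # obs is assumed to be a sequence laid out as [x, y, has_key]
    return {"x": int(obs[0]), "y": int(obs[1]), "has_key": bool(obs[2])}


def check_values(values, expected=PRISM_VARS):
    """Raise KeyError if the dict's keys differ from the expected variable names."""
    keys = set(values)
    missing = expected - keys
    extra = keys - expected
    if missing or extra:
        raise KeyError(f"missing={sorted(missing)}, unexpected={sorted(extra)}")
    return values


values = check_values(obs_to_values([3, 1, 0]))
print(values)  # {'x': 3, 'y': 1, 'has_key': False}
```

Running the check once on a sample observation from `env.reset()` is usually enough to catch a naming mismatch early.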
PreShieldWrapper
PreShieldWrapper queries the shield on every reset() and step() and writes the result into info:
| Key | Type | Description |
|---|---|---|
| `"action_mask"` | `np.ndarray[bool]` | Shape `(n_actions,)`; `True` = allowed |
| `"shield_state_id"` | `int` | The model state id used for this mask |
| `"shield_bitmask"` | `int` | Raw packed bitmask |
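The table does not spell out how `shield_bitmask` relates to `action_mask`. Assuming bit `i` of the packed integer (least-significant bit first) marks action `i` as allowed, a conversion between the two could look like this. The bit order and the `unpack_bitmask` / `pack_mask` helpers are assumptions for illustration, not tempestpy API.

```python
def unpack_bitmask(bitmask: int, n_actions: int) -> list[bool]:
    """Assumed layout: bit i set <=> action i allowed (LSB = action 0)."""
    return [bool((bitmask >> i) & 1) for i in range(n_actions)]


def pack_mask(mask: list[bool]) -> int:
    """Inverse of unpack_bitmask under the same assumed layout."""
    return sum(1 << i for i, allowed in enumerate(mask) if allowed)


mask = unpack_bitmask(0b1011, 4)
print(mask)             # [True, True, False, True]
print(pack_mask(mask))  # 11
```

Wrapping the list with `np.array(mask, dtype=bool)` gives the array type listed for `"action_mask"`.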
Setup
```python
import gymnasium as gym
from tempestpy.shielding import ShieldFactory, ShieldConfig, PreShieldWrapper

factory = ShieldFactory("model.prism", property="Pmax=? [G safe]")
config = ShieldConfig(threshold=0.9)

env = gym.make("MyEnv-v0")
env = PreShieldWrapper(
    env,
    factory,
    config,
    obs_to_values=obs_to_values,
)
```
The shield is computed once during __init__; subsequent reset() / step() calls query the shield.
Using Stable-Baselines 3 MaskablePPO
PreShieldWrapper implements the action_masks() hook expected by SB3's MaskablePPO:
```python
from sb3_contrib import MaskablePPO

model = MaskablePPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100_000)
```
During rollout collection, SB3 calls action_masks() on the environment before each policy forward pass, so no extra integration code is needed to enforce the shield.
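Conceptually, a maskable policy keeps blocked actions out by replacing their logits with a very large negative number before the softmax, so they end up with (effectively) zero probability. The sketch below is a simplified, dependency-free stand-in for that idea, not SB3's implementation; `masked_probs` and `sample_action` are hypothetical names.

```python
import math
import random

NEG = -1e8  # large negative logit assigned to blocked actions


def masked_probs(logits, mask):
    """Softmax over the logits after overriding blocked actions with NEG."""
    z = [logit if ok else NEG for logit, ok in zip(logits, mask)]
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits, mask, rng):
    """Draw one action index according to the masked distribution."""
    probs = masked_probs(logits, mask)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = masked_probs([2.0, 5.0, 0.5], [True, False, True])
print([round(p, 3) for p in probs])  # [0.818, 0.0, 0.182]
```

Action 1 has the highest raw logit, yet it can never be sampled once the mask blocks it.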
Changing the threshold mid-training
```python
# E.g.: tighten the shield after a curriculum phase
env.rebuild(ShieldConfig(threshold=0.99))
```
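Since the shield is otherwise computed once in `__init__`, `rebuild` presumably recomputes it, so it is worth calling only when the threshold actually changes. A minimal sketch of a stepwise curriculum under that assumption; `ThresholdSchedule`, its phase boundaries, and the choice to track timesteps manually are all hypothetical, and in a real loop the returned value would be passed to `env.rebuild(ShieldConfig(threshold=...))`.

```python
import bisect


class ThresholdSchedule:
    """Stepwise curriculum: the threshold tightens at fixed timestep boundaries."""

    def __init__(self, boundaries, thresholds):
        # thresholds[0] applies before boundaries[0]; thresholds[i] from boundaries[i-1] on
        if len(thresholds) != len(boundaries) + 1:
            raise ValueError("need exactly one more threshold than boundaries")
        self.boundaries = list(boundaries)
        self.thresholds = list(thresholds)
        # Assumption: the wrapper was constructed with thresholds[0]
        self.current = self.thresholds[0]

    def threshold_for(self, timestep):
        return self.thresholds[bisect.bisect_right(self.boundaries, timestep)]

    def update(self, timestep):
        """Return the new threshold if it changed since the last call, else None."""
        t = self.threshold_for(timestep)
        if t == self.current:
            return None
        self.current = t
        return t


schedule = ThresholdSchedule(boundaries=[50_000], thresholds=[0.9, 0.99])
for step in (0, 10_000, 50_000, 80_000):
    new = schedule.update(step)
    if new is not None:
        # In a real loop: env.rebuild(ShieldConfig(threshold=new))
        print(f"step {step}: rebuild with threshold={new}")
```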
PostShieldWrapper
PostShieldWrapper lets the agent act freely; if the agent picks a blocked action, the wrapper silently replaces it before calling env.step.
The info dict contains:
| Key | Type | Description |
|---|---|---|
| `"shield_state_id"` | `int` | State id for the current step |
| `"shield_safe_action"` | `int` | The action actually sent to `env.step` |
| `"shield_corrected"` | `bool` | `True` when the agent's action was replaced |
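The correction step that produces these entries can be pictured as the function below. It is an illustrative sketch, not tempestpy's code: the bitmask layout (bit `i` = action `i`), the `post_shield` name, and the selector signature `(state_id, action, allowed) -> int` are assumptions, since the real `post_selector` signature is not documented here.

```python
from typing import Callable

# Assumed selector signature: (state_id, proposed_action, allowed_actions) -> action
Selector = Callable[[int, int, list[int]], int]


def post_shield(state_id: int, action: int, bitmask: int, selector: Selector) -> dict:
    """Return the info entries from the table above for one step."""
    allowed = [a for a in range(bitmask.bit_length()) if (bitmask >> a) & 1]
    corrected = action not in allowed
    # Only consult the selector when the agent's choice is blocked
    safe_action = selector(state_id, action, allowed) if corrected else action
    return {
        "shield_state_id": state_id,
        "shield_safe_action": safe_action,
        "shield_corrected": corrected,
    }


def first_allowed(state_id, action, allowed):
    """Toy selector: fall back to the lowest-numbered allowed action."""
    return allowed[0]


print(post_shield(7, 2, 0b0011, first_allowed))
# {'shield_state_id': 7, 'shield_safe_action': 0, 'shield_corrected': True}
```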
Setup
A post_selector is mandatory: it decides which action replaces the agent's choice when that choice is blocked.
```python
import gymnasium as gym
from tempestpy.shielding import (
    ShieldFactory, ShieldConfig,
    PostShieldWrapper,
    make_best_value_selector,
)

factory = ShieldFactory("model.prism", property="Pmax=? [G safe]")
result = factory.build(ShieldConfig(threshold=0.9))  # build first to get result
selector = make_best_value_selector(factory, result)
config = ShieldConfig(threshold=0.9, post_selector=selector)

env = gym.make("MyEnv-v0")
env = PostShieldWrapper(
    env,
    factory,
    config,
    obs_to_values=obs_to_values,
)
```
Post-selector factory functions
These helpers create ready-made post_selector callables for ShieldConfig.
make_best_value_selector
Replaces a blocked action with the action that has the highest model-checked value in the current state (or the lowest, when is_minimize=True).
```python
from tempestpy.shielding import make_best_value_selector

selector = make_best_value_selector(factory, result, is_minimize=False)
config = ShieldConfig(threshold=0.9, post_selector=selector)
```
Use this when you want the fallback to be as safe as possible according to the model.
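The selection rule itself amounts to an argmax (or argmin) over per-action values. The sketch assumes those values are available as a plain list indexed by action; `best_value_action` is a hypothetical helper, not part of tempestpy, and how `make_best_value_selector` extracts the values from `result` is not shown here.

```python
def best_value_action(values, is_minimize=False):
    """Index of the action with the highest value, or the lowest if is_minimize."""
    pick = min if is_minimize else max
    # Ties resolve to the lowest action index, since min/max keep the first match
    return pick(range(len(values)), key=lambda a: values[a])


values = [0.72, 0.95, 0.88]  # hypothetical model-checked values per action
print(best_value_action(values))                    # 1
print(best_value_action(values, is_minimize=True))  # 0
```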
make_random_safe_selector
Picks uniformly at random from the allowed actions. In dangerous states (no action meets the threshold) the shield's fallback bitmask is used instead.
```python
import numpy as np
from tempestpy.shielding import make_random_safe_selector

selector = make_random_safe_selector(result, rng=np.random.default_rng(42))
config = ShieldConfig(threshold=0.9, post_selector=selector)
```
Use this when you want the correction to vary instead of always substituting the same action, which can help exploration.
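The rule described above (uniform over allowed actions, falling back to a second mask in dangerous states) can be sketched as follows. This is an assumption-laden illustration: the bitmask layout (bit `i` = action `i`), passing the fallback bitmask explicitly, and the `random_safe_action` name are all hypothetical, and it uses Python's `random.Random` instead of the NumPy `Generator` shown above to stay dependency-free.

```python
import random


def random_safe_action(bitmask, fallback_bitmask, n_actions, rng):
    """Uniform choice among allowed actions; use the fallback mask if none are allowed."""
    def actions(mask):
        return [a for a in range(n_actions) if (mask >> a) & 1]

    # An empty list is falsy, so a dangerous state falls through to the fallback mask
    allowed = actions(bitmask) or actions(fallback_bitmask)
    if not allowed:
        raise ValueError("neither mask allows any action")
    return rng.choice(allowed)


rng = random.Random(42)
print(random_safe_action(0b0110, 0b0001, 4, rng))  # 1 or 2
print(random_safe_action(0b0000, 0b1000, 4, rng))  # 3 (fallback)
```

Seeding the generator, as the `default_rng(42)` call above does, keeps the corrections reproducible across runs.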