ShieldFactory & ShieldConfig
A shield introduces per-state constraints: for every state in an MDP or SMG it records which actions are safe enough to take, expressed as a bitmask.
A ShieldFactory builds the model and computes, for every state-action pair, the probability of staying safe.
It does so by running probabilistic model checking on a PRISM model.
To build a concrete shield, you specify the safety threshold via a ShieldConfig.
Quick start
from tempestpy.shielding import ShieldFactory, ShieldConfig
factory = ShieldFactory(
    "model.prism",
    property="Pmax=? [G safe]",
    constants={"N": 5, "slippery": False},
)
config = ShieldConfig(threshold=0.9)
result = factory.build(config)
factory.build(config) runs model checking once, caches the raw choice values, and returns an ExplicitShieldResult (or SymbolicShieldResult for engine="dd").
ShieldFactory
Constructor
| Parameter | Type | Description |
|---|---|---|
| model_path | str | Path to a .prism file. |
| property | str | PCTL/PRCTL property, e.g. "Pmax=? [G safe]". |
| constants | dict | PRISM constants, e.g. {"N": 5, "slippery": False}. |
| check_result | stormpy result | Pre-computed check result; when given, model checking is skipped. |
| engine | "sparse" or "dd" | Backend. "sparse" (default) for explicit-state; "dd" for symbolic. |
Using a Different Property
To use a different property, assign to prop on the existing factory:
# Switch to a different safety objective without rebuilding the model:
factory.prop = "Pmin=? [F crash]"
result2 = factory.build(ShieldConfig(threshold=0.05, comparison="absolute"))
or use with_prop to get a new factory for that property:
factory2 = factory.with_prop("Pmin=? [F crash]")
Querying state information
# Map PRISM variable values to a state id
sid = factory.get_state_id({"x": 3, "y": 1, "done": False})
# Get the raw model-checked choice values for that state
pvals = factory.get_choice_values_for_state(sid) # np.ndarray, one entry per action
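To make the lookup step concrete, here is a minimal sketch of how a table from variable valuations to state ids could work. It is not tempestpy's implementation: `StateLookup`, its methods, and the id `17` are hypothetical and only illustrate that the order of keys in the query dict should not matter.

```python
from typing import Dict, Hashable, Mapping, Tuple

Valuation = Mapping[str, Hashable]
Key = Tuple[Tuple[str, Hashable], ...]


def _key(valuation: Valuation) -> Key:
    """Turn a variable dict into an order-independent, hashable key."""
    return tuple(sorted(valuation.items()))


class StateLookup:
    """Hypothetical table mapping PRISM variable valuations to state ids."""

    def __init__(self) -> None:
        self._ids: Dict[Key, int] = {}

    def add(self, valuation: Valuation, state_id: int) -> None:
        self._ids[_key(valuation)] = state_id

    def get_state_id(self, valuation: Valuation) -> int:
        # Raises KeyError when the valuation is not a known state.
        return self._ids[_key(valuation)]


lookup = StateLookup()
lookup.add({"x": 3, "y": 1, "done": False}, 17)

# Key order in the query dict does not matter.
sid = lookup.get_state_id({"done": False, "y": 1, "x": 3})
```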
Saving and loading
# Slim save: only choice values and state lookup (fastest, smallest file)
factory.save("shield.pkl")
# Full save: also embeds the DRN-serialised model (factory is fully self-contained)
factory.save("shield_full.pkl", slim=False)
# Load back
factory2 = ShieldFactory.load("shield.pkl")
result2 = factory2.compute(config) # reuse cached values; no model checking needed
Slim vs full
slim=True (default) saves only the pre-computed choice values and the state lookup table.
The loaded factory can call compute() but not build() (the model is absent).
Use slim=False when you need the loaded factory to be able to re-run model checking.
ShieldConfig
A ShieldConfig specifies how a concrete shield should be built using a ShieldFactory.
from tempestpy.shielding import ShieldConfig
config = ShieldConfig(
    threshold=0.9,          # default: 1.0
    comparison="relative",  # default: "absolute"
)
Threshold and comparison mode
absolute (default): action i is allowed when its model-checked value ≥ threshold (for Pmax) or ≤ threshold (for Pmin).
# Allow any action achieving at least 80 % safety probability
config = ShieldConfig(threshold=0.8, comparison="absolute")
relative: threshold is scaled by the per-state best value before comparison.
An action is allowed when its value ≥ threshold × best_value_in_state.
# Allow actions within 5 % of the best achievable value in each state
config = ShieldConfig(threshold=0.95, comparison="relative")
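The two comparison modes can be sketched in plain Python. This is an illustration of the rules stated above, not the library's code: `allowed_mask` and its `minimize` flag are hypothetical names, and the choice values are made up. The relative rule is only sketched for maximised values, since that is the only case described here.

```python
from typing import Sequence


def allowed_mask(values: Sequence[float], threshold: float,
                 comparison: str = "absolute", minimize: bool = False) -> int:
    """Return a bitmask whose bit i is set when action i passes the threshold."""
    if comparison == "absolute":
        if minimize:
            passes = [v <= threshold for v in values]  # Pmin property
        else:
            passes = [v >= threshold for v in values]  # Pmax property
    elif comparison == "relative":
        if minimize:
            # Assumption: the minimising case is left out of this sketch.
            raise NotImplementedError("relative mode sketched for Pmax only")
        cutoff = threshold * max(values)  # scale by the per-state best value
        passes = [v >= cutoff for v in values]
    else:
        raise ValueError(f"unknown comparison mode: {comparison!r}")

    mask = 0
    for i, ok in enumerate(passes):
        if ok:
            mask |= 1 << i
    return mask


safety = [0.95, 0.85, 0.50]  # made-up safety probabilities for one state
abs_mask = allowed_mask(safety, 0.9)              # only action 0 passes
rel_mask = allowed_mask(safety, 0.8, "relative")  # actions 0 and 1 pass

crash = [0.01, 0.20, 0.04]   # made-up crash probabilities for one state
min_mask = allowed_mask(crash, 0.05, minimize=True)  # actions 0 and 2 pass
```

In relative mode the cutoff moves with the state: where even the best action is risky, the shield still permits the actions closest to that best value instead of blocking everything.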
post_selector
A post_selector is only needed for post-shielding.
It is a callable (state_id, blocked_action) -> replacement_action invoked when the agent picks a blocked action.
See the Wrappers page for ready-made selectors.
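If none of the ready-made selectors fit, a custom one is just a function with the signature above. The following is a minimal sketch; `make_table_selector` and the fallback table are hypothetical, and passing the result as `ShieldConfig(post_selector=...)` is an assumption based on the name of this option.

```python
from typing import Callable, Dict

PostSelector = Callable[[int, int], int]


def make_table_selector(safe_action: Dict[int, int], default: int) -> PostSelector:
    """Build a selector that replaces a blocked action with a per-state fallback."""

    def selector(state_id: int, blocked_action: int) -> int:
        # blocked_action is ignored in this sketch; a smarter selector could
        # prefer the allowed action closest to what the agent wanted.
        return safe_action.get(state_id, default)

    return selector


# Hypothetical fallback table: state 7 -> action 2, state 8 -> action 0.
selector = make_table_selector({7: 2, 8: 0}, default=1)
replacement = selector(7, 3)  # the agent chose blocked action 3 in state 7
```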
action_lookup
By default, bit i of the bitmask corresponds to local action index i in the PRISM model.
When the RL environment uses a different global action numbering, supply an ActionLookup:
from tempestpy.shielding.action_lookup import ActionLookup
lookup = ActionLookup({"left": 0, "right": 1, "up": 2, "down": 3})
config = ShieldConfig(threshold=0.9, action_lookup=lookup)
Bit i in the resulting bitmask then corresponds to global RL action index i.
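What the remapping does to a mask can be sketched as follows. This is an illustration only: `remap_mask` is a hypothetical helper, and it assumes each local action in a state carries a label that appears in the lookup dictionary.

```python
from typing import Dict, Sequence


def remap_mask(local_mask: int, local_labels: Sequence[str],
               global_index: Dict[str, int]) -> int:
    """Move each set bit from its local position to the global action index."""
    global_mask = 0
    for local_i, label in enumerate(local_labels):
        if local_mask & (1 << local_i):
            global_mask |= 1 << global_index[label]
    return global_mask


global_index = {"left": 0, "right": 1, "up": 2, "down": 3}

# Hypothetical state in which the model offers only "up" and "left", in that order.
local_labels = ["up", "left"]
local_mask = 0b01  # local action 0 ("up") is allowed, "left" is blocked

global_mask = remap_mask(local_mask, local_labels, global_index)
```

Here the single allowed local action "up" ends up at bit 2 of the global mask, which is the index the RL environment uses for it.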
ShieldResult
Both build() and compute() return a ShieldResult subclass.
For the sparse engine that is ExplicitShieldResult (states are integer ids);
for engine="dd" it is SymbolicShieldResult (states are PRISM variable dicts).
Querying the mask
# Pre-shielding: get the bitmask for a state (falls back when nothing passes)
bits = result.query_mask(sid) # int; bit i = action i is allowed
# Post-shielding: pass agent's chosen action through the shield
safe = result.query_post(sid, agent_action) # returns agent_action or selector result
# Check a single action
allowed = result.is_action_allowed(sid, action_index) # bool
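The integer mask can be unpacked with ordinary bit operations. The sketch below shows one way to list the allowed actions and to mimic the post-shielding rule (keep the agent's action if allowed, otherwise ask the selector). The helper names are hypothetical and do not come from tempestpy.

```python
from typing import Callable, List


def allowed_actions(bits: int, num_actions: int) -> List[int]:
    """List the action indices whose bit is set in the mask."""
    return [i for i in range(num_actions) if bits & (1 << i)]


def is_allowed(bits: int, action: int) -> bool:
    """Check a single action against the mask."""
    return bool(bits & (1 << action))


def post_shield(bits: int, state_id: int, agent_action: int,
                selector: Callable[[int, int], int]) -> int:
    """Keep the agent's action when allowed, otherwise ask the selector."""
    if is_allowed(bits, agent_action):
        return agent_action
    return selector(state_id, agent_action)


bits = 0b0101  # example mask: actions 0 and 2 are allowed
choices = allowed_actions(bits, num_actions=4)

# A trivial selector that always falls back to action 0.
kept = post_shield(bits, 0, 2, lambda s, a: 0)      # action 2 is allowed
replaced = post_shield(bits, 0, 3, lambda s, a: 0)  # action 3 is blocked
```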