API Reference
tempestpy.shielding.ShieldFactory(model_path: str, *, property: Optional[str] = None, constants: Optional[Dict[str, Any]] = None, check_result: Optional[Any] = None, engine: str = 'sparse')
Entry point for computing probabilistic shields from PRISM models.
A ShieldFactory parses a PRISM MDP/SMG, runs model
checking to obtain per-action choice values, and converts those values
into ShieldResult bitmasks via build or compute.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_path | str | Path to the PRISM model file (MDP or SMG). | required |
| property | Optional[str] | A PCTL/PRCTL property string, e.g. "Pmax=? [G safe]". | None |
| constants | Optional[Dict[str, Any]] | PRISM constant definitions, e.g. {"N": 5}. | None |
| check_result | Optional[Any] | A pre-computed stormpy check result that already contains choice values. When provided, model checking is skipped entirely. | None |
| engine | str | Backend representation: 'sparse' (explicit-state, yields ExplicitShieldResult) or 'dd' (symbolic, yields SymbolicShieldResult). | 'sparse' |
Raises:
| Type | Description |
|---|---|
| ValueError | If … |
| TypeError | If the model type or engine/model combination is unsupported. |
Examples:
>>> from tempestpy.shielding import ShieldConfig, ShieldFactory
>>> factory = ShieldFactory("model.prism", property="Pmax=? [G safe]",
...                         constants={"N": 5})
>>> result = factory.build(ShieldConfig(threshold=0.9))
model: Any (property)
The stormpy model held by this factory.
prism_program: Any (property)
The parsed storm::prism::Program held by this factory.
prop: Optional[str] (property, writable)
The active PCTL property string, or None if not set.
formula: Any (property)
The parsed stormpy formula, or None before a property is set.
optimality_type: Any (property)
The optimality direction (Minimize or Maximize) derived from the formula, or None if no formula has been set.
is_sparse: bool (property)
True when the factory operates in sparse-engine mode.
state_lookup: StateValuationLookup (property)
Lazily initialised bidirectional map between state ids and PRISM valuations.
with_prop(prop: str) -> 'ShieldFactory'
Return a copy of this factory with a different property string.
build(config: ShieldConfig) -> Any
Run model checking (if needed) and compute the ShieldResult.
If model checking has not been performed yet, it is triggered here. Subsequent calls with the same property reuse the cached result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ShieldConfig | Threshold, comparison mode, and optional post-selector / action lookup settings for the ShieldResult computation. | required |
Returns:
| Type | Description |
|---|---|
| ShieldResult | Per-state bitmasks encoding which actions are permitted. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If neither a formula nor a check result is available. |
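The conversion that build performs, from per-action choice values to a bitmask, can be illustrated with a small self-contained sketch. It does not call tempestpy: threshold_bitmask and best_value_fallback are hypothetical helpers, and both the comparison rule (value >= threshold for maximising properties, value <= threshold for minimising ones) and the one-hot best-value fallback are assumptions made for illustration, not the library's documented semantics.

```python
from typing import Sequence


def threshold_bitmask(values: Sequence[float], threshold: float, *, maximize: bool = True) -> int:
    """Set bit i when action i's choice value passes the threshold (assumed rule)."""
    mask = 0
    for i, value in enumerate(values):
        # Assumption: maximising properties keep value >= threshold,
        # minimising properties keep value <= threshold.
        passes = value >= threshold if maximize else value <= threshold
        if passes:
            mask |= 1 << i
    return mask


def best_value_fallback(values: Sequence[float], *, maximize: bool = True) -> int:
    """One-hot mask of the best-valued action, or 0 for a state without actions."""
    if len(values) == 0:
        return 0
    pick = max if maximize else min
    best = pick(range(len(values)), key=lambda i: values[i])
    return 1 << best


print(bin(threshold_bitmask([0.95, 0.50, 0.92], 0.9)))  # 0b101
print(bin(best_value_fallback([0.2, 0.7, 0.4])))         # 0b10
```

In this sketch the fallback only matters when threshold_bitmask returns 0, which is the situation the reference calls a dangerous state.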
compute(config: ShieldConfig) -> Any
Compute the ShieldResult from an already-available check result.
Unlike build, this method does not run model checking.
A check result must be present.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ShieldConfig | Threshold, comparison mode, and optional post-selector / action lookup settings. | required |
Returns:
| Type | Description |
|---|---|
| ShieldResult | Per-state bitmasks encoding which actions are permitted. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If no check result is available. |
get_choice_values_for_state(sid: int) -> np.ndarray
Return the model-checked choice values for all actions in state sid.
get_state_id(values: dict) -> int
Look up the state id for a given PRISM variable valuation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| values | dict | Mapping of PRISM variable name to value, e.g. {"x": 4, "y": 7}. | required |
Returns:
| Type | Description |
|---|---|
| int | The model state id corresponding to values. |
Raises:
| Type | Description |
|---|---|
| KeyError | If the valuation does not match any state. |
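A rough sketch of the kind of bidirectional mapping that state_lookup and get_state_id rely on. ValuationLookupSketch is a hypothetical stand-in, not tempestpy's StateValuationLookup; it assumes each state id maps to exactly one complete valuation, and it uses frozensets of (name, value) pairs as keys because plain dicts are not hashable.

```python
from typing import Any, Dict, FrozenSet, List, Tuple

Valuation = Dict[str, Any]
ValKey = FrozenSet[Tuple[str, Any]]


class ValuationLookupSketch:
    """Hypothetical stand-in for StateValuationLookup (not tempestpy's class)."""

    def __init__(self, valuations: List[Valuation]) -> None:
        # The position of each valuation in the list is its state id.
        self._by_id = [dict(v) for v in valuations]
        # Reverse map keyed on frozensets, so variable order does not matter.
        self._by_key: Dict[ValKey, int] = {
            frozenset(v.items()): sid for sid, v in enumerate(self._by_id)
        }

    def get_state_id(self, values: Valuation) -> int:
        key = frozenset(values.items())
        if key not in self._by_key:
            raise KeyError(f"no state matches valuation {values!r}")
        return self._by_key[key]

    def get_valuation(self, sid: int) -> Valuation:
        return dict(self._by_id[sid])


lookup = ValuationLookupSketch([{"x": 0, "y": 0}, {"x": 4, "y": 7}])
print(lookup.get_state_id({"y": 7, "x": 4}))  # 1
```

A partial valuation (for example {"x": 4} alone) produces a different key and raises KeyError, mirroring the documented behaviour of get_state_id.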
dangerous_states(*, shield_result: ShieldResult) -> list[dict]
Return the PRISM valuations of all dangerous states in shield_result.
A state is dangerous when no action meets the shield threshold and the fallback (best-value) action is used instead.
critical_states(*, shield_result: ShieldResult) -> list[dict]
Return the PRISM valuations of all critical states in shield_result.
A state is critical when at least one action is blocked by the threshold but at least one action still passes.
safe_states(*, shield_result: ShieldResult) -> list[dict]
Return the PRISM valuations of all safe states in shield_result.
A state is safe when every action passes the shield threshold.
states_by_class(*, shield_result: ShieldResult) -> dict[str, list[dict]]
Partition all states into the four shield classes (unsafe, dangerous, critical, safe), listing the PRISM valuations of the states in each class.
save(path: str, *, slim: bool = True) -> None
Persist this factory to a pickle file.
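Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | Destination file path. | required |
| slim | bool | If True, store only the choice values and the state lookup. If False, also store the Storm model (as DRN) so that load can restore it. | True |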
load(path: str) -> 'ShieldFactory' (classmethod)
Load a ShieldFactory from a pickle file produced by save.
Notes
If saved with slim=True: only choice values and state lookup are
available; the Storm model is not restored.
If saved with slim=False: the Storm model is restored from DRN;
call build() to re-run model checking if needed.
to_bitmask(shield_result: Any) -> dict[int, Bitmask]
Convert shield_result to a {state_id: bitmask} dictionary.
to_allowed_action_indices(shield_result: Any) -> dict[int, list[int]]
Convert shield_result to a {state_id: [action_index, ...]} dictionary.
to_allowed_action_labels(shield_result: Any) -> dict[int, list]
Convert shield_result to a {state_id: [frozenset(labels), ...]} dictionary.
to_allowed_state_action_pairs(shield_result: Any) -> list[tuple[int, int]]
Return all allowed (state_id, local_action_index) pairs.
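The index- and pair-style conversions amount to unpacking each state's bitmask. The sketch below assumes the explicit layout of ExplicitShieldResult (lists indexed by state id) and reads the raw threshold masks; whether the real methods apply the fallback mask first is not stated in this reference. All function names are hypothetical.

```python
from typing import Dict, List, Tuple


def allowed_indices(mask: int, nr_actions: int) -> List[int]:
    """Local action indices whose bit is set in mask."""
    return [i for i in range(nr_actions) if (mask >> i) & 1]


def to_indices(bitmasks: List[int], nr_actions: List[int]) -> Dict[int, List[int]]:
    # Same shape as to_allowed_action_indices: {state_id: [action_index, ...]}
    return {
        sid: allowed_indices(mask, n)
        for sid, (mask, n) in enumerate(zip(bitmasks, nr_actions))
    }


def to_pairs(bitmasks: List[int], nr_actions: List[int]) -> List[Tuple[int, int]]:
    # Same shape as to_allowed_state_action_pairs: [(state_id, local_action_index), ...]
    return [
        (sid, a)
        for sid, idxs in to_indices(bitmasks, nr_actions).items()
        for a in idxs
    ]


bitmasks = [0b101, 0b010, 0b000]
nr_actions = [3, 2, 1]
print(to_indices(bitmasks, nr_actions))  # {0: [0, 2], 1: [1], 2: []}
print(to_pairs(bitmasks, nr_actions))    # [(0, 0), (0, 2), (1, 1)]
```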
to_valuation_bitmask(shield_result: Any) -> dict[dict, Bitmask]
Convert shield_result to a {valuation_key: bitmask} dictionary.
to_valuation_allowed_action_labels(shield_result: Any) -> dict
Convert shield_result to a {valuation_key: [frozenset(labels), ...]} dictionary.
to_valuation_action_label_probability() -> dict
Map every state to all its actions and their model-checked probabilities: {valuation_key: {frozenset(action_labels): probability}}
pretty(shield_result: Any, *, max_states: int = 25, show_valuations: bool = True, action_ref: str = 'labels') -> str
Return a human-readable summary of shield_result for debugging.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shield_result | Any | A ShieldResult to summarise. | required |
| max_states | int | Maximum number of states to include in the output. | 25 |
| show_valuations | bool | If True, include each state's PRISM valuation in the output. | True |
| action_ref | str | How to identify actions: … | 'labels' |
Returns:
| Type | Description |
|---|---|
| str | Multi-line human-readable shield summary. |
to_storm_format(shield_result: Any) -> str
Render shield_result as a Storm-compatible shield text.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shield_result | Any | A ShieldResult to render. | required |
Returns:
| Type | Description |
|---|---|
| str | Multi-line string matching Storm's shield output format, with per-action choice values and valuation annotations. |
tempestpy.shielding.ShieldConfig(threshold: float = 1.0, comparison: ComparisonMode = 'absolute', post_selector: Optional[PostSelectorFn] = None, action_lookup: Optional[ActionLookup] = None)
dataclass
Immutable configuration for a single shield computation.
Attributes:
| Name | Type | Description |
|---|---|---|
| threshold | float | The probability threshold. In … |
| comparison | ComparisonMode | How to interpret threshold (default 'absolute'): … |
| post_selector | Optional[PostSelectorFn] | Optional callable invoked by ShieldResult.query_post to choose a replacement when the agent's action is blocked. |
| action_lookup | Optional[ActionLookup] | Optional … |
tempestpy.shielding.ShieldResult(config: Any, bit_width: Optional[int] = None, coalition_states: Optional[frozenset] = None)
dataclass
Bases: ABC
Base class for shield results produced by explicit and symbolic backends.
Explicit shields are keyed by integer state ids. Symbolic shields are
keyed by PRISM variable valuations, for example {"x": 4, "y": 7}.
query_mask(state: Any) -> int
Return the effective pre-shielding bitmask for state.
Bit i is set when action i is permitted. If no action meets the threshold (raw bitmask is 0), the fallback bitmask is returned instead, so the result is always non-zero unless the state is truly unsafe.
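The fallback rule described for query_mask is a simple choice between two masks. This minimal sketch works on plain integers; effective_mask and passes_threshold are hypothetical helpers, and treating is_action_allowed as a test against the raw (non-fallback) mask is an assumption based on its "passes the shield threshold" wording.

```python
def effective_mask(raw: int, fallback: int) -> int:
    """Mask as described for query_mask: the raw threshold mask, or fallback if raw is 0."""
    return raw if raw != 0 else fallback


def passes_threshold(raw: int, action_index: int) -> bool:
    """is_action_allowed-style check against the raw threshold mask (assumed)."""
    return bool((raw >> action_index) & 1)


print(effective_mask(0b0110, 0b0001))  # 6: actions 1 and 2 meet the threshold
print(effective_mask(0b0000, 0b0001))  # 1: dangerous state, fallback action 0
print(effective_mask(0b0000, 0b0000))  # 0: unsafe state
```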
query_post(state: Any, agent_action: int, **kwargs: Any) -> int
Return agent_action if it is allowed by the shield, otherwise
delegate to ShieldConfig.post_selector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | Any | Current state key (int for explicit, dict for symbolic). | required |
| agent_action | int | The action index the agent wants to take. | required |
| **kwargs | Any | Forwarded verbatim to the post-selector. | {} |
Returns:
| Type | Description |
|---|---|
| int | The action index to actually execute. |
Raises:
| Type | Description |
|---|---|
| TypeError | If the action is blocked and no post_selector is configured. |
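The pass-through/delegate behaviour of query_post can be sketched without tempestpy. query_post_sketch is hypothetical: it takes the state's mask and the configured post-selector as explicit arguments. With post_selector=None a blocked action raises TypeError, matching the exception type listed for query_post.

```python
from typing import Any, Callable, Optional

PostSelector = Callable[..., int]


def query_post_sketch(
    mask: int,
    state: Any,
    agent_action: int,
    post_selector: Optional[PostSelector],
    **kwargs: Any,
) -> int:
    """Hypothetical query_post: pass allowed actions through, else delegate."""
    if (mask >> agent_action) & 1:
        return agent_action  # allowed: executed unchanged
    if post_selector is None:
        raise TypeError("action is blocked and no post_selector is configured")
    # Blocked: the post-selector picks the replacement; extra kwargs pass through.
    return post_selector(state, agent_action, **kwargs)


def always_zero(state: Any, action: int) -> int:
    return 0


print(query_post_sketch(0b101, 3, 2, always_zero))  # 2 (allowed, unchanged)
print(query_post_sketch(0b101, 3, 1, always_zero))  # 0 (blocked, replaced)
```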
query_action(state: Any, agent_action: int, **kwargs: Any) -> int
Alias for query_post for callers that prefer action-oriented naming.
is_action_allowed(state: Any, action_index: int) -> bool
Return True when action_index passes the shield threshold at state.
is_coalition_state(sid: int) -> bool
Return True when sid belongs to the shielded coalition.
For plain MDPs (no SMG), always returns True. For SMGs, returns
True only for ego-player states; opponent states are never blocked.
is_unsafe(state: Any) -> bool
Return True when no action is available and no fallback exists.
Both the threshold bitmask and the fallback bitmask are zero, meaning the shield has no safe or even best-effort action to offer.
is_dangerous(state: Any) -> bool
Return True when no action met the threshold and the fallback is active.
The fallback bitmask is non-zero, so the shield can still suggest an action, but none of the actions meets the safety threshold.
is_critical(state: Any) -> bool
Return True when at least one action is blocked but at least one passes.
The threshold bitmask is non-zero and fewer than all actions are permitted, with no fallback active.
is_safe(state: Any) -> bool
Return True when every available action passes the shield threshold.
The threshold bitmask covers all actions and no fallback is active.
classify_state(state: Any) -> str
Classify state into one of four shield categories.
Returns:
| Type | Description |
|---|---|
| str | The state's shield category: unsafe, dangerous, critical, or safe. |
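The four categories follow from the raw threshold mask, the fallback mask, and the action count, as described for is_unsafe, is_dangerous, is_critical and is_safe. The sketch below is a hypothetical restatement of those rules on plain integers; the label strings are illustrative and may not match what classify_state actually returns. The final lines group state ids by class in the spirit of states_by_class.

```python
from typing import Dict, List, Tuple


def classify_sketch(raw: int, fallback: int, nr_actions: int) -> str:
    """Hypothetical mirror of the four shield categories (labels are illustrative)."""
    all_actions = (1 << nr_actions) - 1
    if raw == 0:
        # Nothing meets the threshold: the fallback decides dangerous vs unsafe.
        return "dangerous" if fallback != 0 else "unsafe"
    if (raw & all_actions) == all_actions:
        return "safe"  # every action passes
    return "critical"  # some, but not all, actions pass


# state id -> (raw threshold mask, fallback mask, number of actions)
states: Dict[int, Tuple[int, int, int]] = {
    0: (0b111, 0, 3),
    1: (0b011, 0, 3),
    2: (0, 0b100, 3),
    3: (0, 0, 2),
}
by_class: Dict[str, List[int]] = {}
for sid, (raw, fb, n) in states.items():
    by_class.setdefault(classify_sketch(raw, fb, n), []).append(sid)
print(by_class)  # {'safe': [0], 'critical': [1], 'dangerous': [2], 'unsafe': [3]}
```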
tempestpy.shielding.ExplicitShieldResult(config: Any, bit_width: Optional[int] = None, coalition_states: Optional[frozenset] = None, bitmask_by_state: list[Bitmask] = list(), fallback_by_state: list[Bitmask] = list(), nr_actions_by_state: list[int] = list(), best_value_by_state: Optional[np.ndarray] = None)
dataclass
Bases: ShieldResult
Shield result produced by the sparse explicit-state backend (engine="sparse").
States are indexed by integer state id. All query methods accept an int
and index into the per-state lists directly.
Attributes:
| Name | Type | Description |
|---|---|---|
| bitmask_by_state | list[Bitmask] | Per-state allowed-action bitmask. Bit i is set when action i meets the shield threshold. |
| fallback_by_state | list[Bitmask] | Per-state fallback bitmask used when no action meets the threshold (the best available action, or 0 when none exists). |
| nr_actions_by_state | list[int] | Total number of actions available in each state (used for …). |
| best_value_by_state | Optional[ndarray] | Optional array of per-state best model-checked values, filled when the factory computes best-value selectors. |
tempestpy.shielding.SymbolicShieldResult(config: Any, bit_width: Optional[int] = None, coalition_states: Optional[frozenset] = None, bitmask_add: Any = None, fallback_add: Any = None, nr_actions: int = 0)
dataclass
Bases: ShieldResult
Shield result produced by the DD-backed symbolic backend (engine="dd").
States are identified by PRISM variable valuations rather than integer ids,
e.g. {"x": 4, "y": 7}. All query methods accept a dict and
evaluate the underlying algebraic decision diagrams (ADDs) at that point.
Attributes:
| Name | Type | Description |
|---|---|---|
| bitmask_add | Add | ADD encoding the allowed-action Bitmask of each state. |
| fallback_add | Add, optional | ADD encoding the fallback Bitmask of each state. |
| nr_actions | int | Uniform action count for the model, used by … |
Gymnasium Wrappers
tempestpy.shielding.PreShieldWrapper(env: gym.Env, factory: ShieldFactory, config: ShieldConfig, *, obs_to_values: ObsToValuesFn)
Bases: Wrapper
Gymnasium wrapper that adds pre-shield action masking.
On every reset() and step(), the wrapper queries the shield and
injects an "action_mask" boolean array into info. The mask can
be consumed directly or via the action_masks() hook for SB3
MaskablePPO.
The info dict after reset() / step() contains:
- "action_mask": np.ndarray[bool], shape (n_actions,)
- "shield_state_id": int, state id used for the current mask
- "shield_bitmask": int, raw packed bitmask
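The "action_mask" entry is the packed bitmask unpacked into one boolean per action. A minimal NumPy sketch, assuming bit i corresponds to action i as stated for ShieldResult.query_mask; both helper names are hypothetical and not part of tempestpy.

```python
import numpy as np


def bitmask_to_action_mask(mask: int, n_actions: int) -> np.ndarray:
    """Unpack bit i of a packed bitmask into entry i of a bool array."""
    return np.array([(mask >> i) & 1 for i in range(n_actions)], dtype=bool)


def action_mask_to_bitmask(action_mask: np.ndarray) -> int:
    """Inverse direction: pack a bool array back into a Python int."""
    return sum(1 << int(i) for i in np.flatnonzero(action_mask))


mask = bitmask_to_action_mask(0b101, 4)
print(mask.tolist())                 # [True, False, True, False]
print(action_mask_to_bitmask(mask))  # 5
```

Shifting a Python int rather than using a fixed-width NumPy integer keeps the unpacking correct even for models with more than 64 actions.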
See Also
PostShieldWrapper : Corrects the agent's chosen action after the fact.
ShieldFactory.build : Produces the ShieldResult used internally.
Initialise a PreShieldWrapper.
Parameters:
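| Name | Type | Description | Default |
|---|---|---|---|
| env | Env | The Gymnasium environment to wrap. Must have a … | required |
| factory | ShieldFactory | A ShieldFactory used to build the shield. | required |
| config | ShieldConfig | Configuration passed to factory.build to compute the shield. | required |
| obs_to_values | ObsToValuesFn | Callable mapping an observation to a PRISM variable valuation. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If … |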
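obs_to_values is the bridge between Gymnasium observations and PRISM valuations. The sketch below assumes a hypothetical environment whose observation begins with integer grid coordinates corresponding to PRISM variables x and y. The exact ObsToValuesFn contract is not spelled out in this reference, so the signature here is an assumption.

```python
from typing import Any, Dict, Sequence


def grid_obs_to_values(obs: Sequence[float]) -> Dict[str, Any]:
    """Hypothetical obs_to_values for an env whose observation starts with (x, y).

    Assumes the PRISM model declares integer variables x and y and that these
    two variables fully determine the state; a dict that matches no state
    would presumably fail the lookup (get_state_id raises KeyError).
    """
    return {"x": int(obs[0]), "y": int(obs[1])}


# With the real classes (not executed here):
# wrapped = PreShieldWrapper(env, factory, config, obs_to_values=grid_obs_to_values)
```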
action_masks() -> np.ndarray
Return the current boolean action mask.
Returns:
| Type | Description |
|---|---|
| ndarray | Boolean array of shape (n_actions,). |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If called before the first reset(). |
rebuild(config: ShieldConfig) -> None
Recompute the shield with a new ShieldConfig, allowing the threshold to be changed mid-training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ShieldConfig | New shield configuration. | required |
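Because ShieldConfig is documented as an immutable dataclass, changing the threshold mid-training means constructing a new config and passing it to rebuild. dataclasses.replace works on any dataclass instance; to keep the example runnable without tempestpy, it uses a hypothetical frozen stand-in, ConfigStandIn, and an arbitrary linear threshold schedule.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class ConfigStandIn:
    """Hypothetical stand-in for ShieldConfig (only two of its fields)."""

    threshold: float = 1.0
    comparison: str = "absolute"


def scheduled_threshold(step: int, total: int, start: float = 0.99, end: float = 0.9) -> float:
    """Linear interpolation from start to end, clamped once step reaches total."""
    frac = min(max(step / total, 0.0), 1.0)
    return start + (end - start) * frac


base = ConfigStandIn(threshold=0.99)
# replace() builds a new instance; the frozen original is left untouched.
new_config = replace(base, threshold=scheduled_threshold(5_000, 10_000))

# With the real objects (not executed here):
# wrapper.rebuild(replace(config, threshold=scheduled_threshold(step, total)))
```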
tempestpy.shielding.PostShieldWrapper(env: gym.Env, factory: ShieldFactory, config: ShieldConfig, *, obs_to_values: ObsToValuesFn)
Bases: Wrapper
Gymnasium wrapper that enforces shield safety via post-shielding.
The agent selects any action from the unmasked observation. Before the
action reaches env.step, the wrapper checks the shield: if the action
is allowed it passes through unchanged; if blocked, ShieldConfig.post_selector
is called to supply a safe replacement.
The info dict after reset() / step() contains:
- "shield_state_id": int, state id used for the next step
- "shield_safe_action": int, the action actually passed to env.step
- "shield_corrected": bool, True when the agent's action was replaced
See Also
PreShieldWrapper : Injects an action mask before the agent acts.
make_best_value_selector : A ready-made post-selector for this wrapper.
Initialise a PostShieldWrapper.
Parameters:
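| Name | Type | Description | Default |
|---|---|---|---|
| env | Env | The Gymnasium environment to wrap. Must have a … | required |
| factory | ShieldFactory | A ShieldFactory used to build the shield. | required |
| config | ShieldConfig | Shield configuration; must have a non-None post_selector. | required |
| obs_to_values | ObsToValuesFn | Callable mapping an observation to a PRISM variable valuation. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If … |
| ValueError | If config.post_selector is None. |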
rebuild(config: ShieldConfig) -> None
Recompute the shield with a new ShieldConfig, allowing the threshold to be changed mid-training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ShieldConfig | New configuration. Must have a non-None post_selector. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If config.post_selector is None. |
tempestpy.shielding.make_best_value_selector(factory: ShieldFactory, result: ShieldResult, *, is_minimize: bool = False) -> Callable[[int, int], int]
Create a post-selector that replaces a blocked action with the best-value action.
The returned callable closes over factory to retrieve per-state choice values at call time. Works in both normal and dangerous states; in dangerous states it falls back to the argmax/argmin action.
Parameters:
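| Name | Type | Description | Default |
|---|---|---|---|
| factory | ShieldFactory | The ShieldFactory whose choice values are used. | required |
| result | ShieldResult | The ShieldResult whose blocked actions are replaced. | required |
| is_minimize | bool | If True, select the action with the lowest choice value (argmin) instead of the highest (argmax). | False |
Returns:
| Type | Description |
|---|---|
| Callable[[int, int], int] | A post-selector mapping (state, agent_action) to a replacement action, suitable for ShieldConfig.post_selector. |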
Examples:
>>> selector = make_best_value_selector(factory, result, is_minimize=False)
>>> config = ShieldConfig(threshold=0.9, post_selector=selector)
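The selection rule can be approximated without tempestpy. The sketch is hypothetical: it takes per-state choice values and raw masks directly instead of a factory and a result, restricts the choice to allowed actions in normal states (an assumption), and in dangerous states falls back to a plain argmax/argmin over all actions, as the description above states. States with no actions are not handled.

```python
from typing import Callable, List, Sequence


def make_best_value_selector_sketch(
    choice_values: Sequence[Sequence[float]],
    masks: Sequence[int],
    *,
    is_minimize: bool = False,
) -> Callable[[int, int], int]:
    """Hypothetical stand-in: values and masks are indexed by explicit state id."""

    def select(state: int, agent_action: int) -> int:
        values = choice_values[state]
        n = len(values)
        allowed: List[int] = [i for i in range(n) if (masks[state] >> i) & 1]
        # Dangerous state (no action passes): consider every action, which
        # amounts to a plain argmax/argmin over the choice values.
        candidates = allowed if allowed else list(range(n))
        pick = min if is_minimize else max
        # agent_action is ignored: the replacement depends only on the values.
        return pick(candidates, key=lambda i: values[i])

    return select


values = [[0.2, 0.9, 0.95], [0.1, 0.4, 0.3]]
select = make_best_value_selector_sketch(values, [0b011, 0b000])
print(select(0, 2))  # 1: best among the allowed actions 0 and 1
print(select(1, 0))  # 1: dangerous state, argmax over all actions
```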
tempestpy.shielding.make_random_safe_selector(result: ShieldResult, *, rng: Optional[np.random.Generator] = None) -> Callable[[int, int], int]
Create a post-selector that replaces a blocked action with a random allowed action.
In dangerous states (no action meets the threshold) the fallback action
from ExplicitShieldResult.fallback_by_state is used instead.
Parameters:
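| Name | Type | Description | Default |
|---|---|---|---|
| result | ShieldResult | The ShieldResult whose allowed actions are sampled. | required |
| rng | Optional[Generator] | Optional NumPy random generator. If None, … | None |
Returns:
| Type | Description |
|---|---|
| Callable[[int, int], int] | A post-selector mapping (state, agent_action) to a replacement action, suitable for ShieldConfig.post_selector. |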
Examples:
>>> selector = make_random_safe_selector(result, rng=np.random.default_rng(42))
>>> config = ShieldConfig(threshold=0.9, post_selector=selector)
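A self-contained approximation of the random selector, assuming explicit per-state lists of masks, fallback masks and action counts. The helper names are hypothetical; taking the first set bit of the fallback mask and drawing from an unseeded default_rng() when rng is None are assumptions, not documented behaviour.

```python
from typing import Callable, List, Optional

import numpy as np


def _indices(mask: int, n: int) -> List[int]:
    return [i for i in range(n) if (mask >> i) & 1]


def make_random_safe_selector_sketch(
    masks: List[int],
    fallbacks: List[int],
    nr_actions: List[int],
    *,
    rng: Optional[np.random.Generator] = None,
) -> Callable[[int, int], int]:
    # Assumption: with rng=None, draw from a fresh, unseeded generator.
    gen = rng if rng is not None else np.random.default_rng()

    def select(state: int, agent_action: int) -> int:
        allowed = _indices(masks[state], nr_actions[state])
        if allowed:
            return int(gen.choice(allowed))  # uniform over allowed actions
        fallback = _indices(fallbacks[state], nr_actions[state])
        if not fallback:
            raise RuntimeError(f"state {state} is unsafe: no allowed or fallback action")
        return fallback[0]  # dangerous state: use the fallback action

    return select


select = make_random_safe_selector_sketch(
    [0b101, 0b000], [0b000, 0b010], [3, 3], rng=np.random.default_rng(42)
)
print(select(1, 0))  # 1: dangerous state, fallback action
```

Passing a seeded generator makes the sequence of replacements reproducible across runs.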
tempestpy.shielding.make_nearest_action_selector(result: ShieldResult) -> Callable[[int, int], int]
Create a post-selector that replaces a blocked action with the nearest allowed action.
Nearest is measured by absolute index distance, making this useful when
action indices have a natural ordering (e.g. discretised velocity).
Falls back to ExplicitShieldResult.fallback_by_state in dangerous states.
Parameters:
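| Name | Type | Description | Default |
|---|---|---|---|
| result | ShieldResult | The ShieldResult whose allowed actions are searched. | required |
Returns:
| Type | Description |
|---|---|
| Callable[[int, int], int] | A post-selector mapping (state, agent_action) to a replacement action, suitable for ShieldConfig.post_selector. |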
Examples:
>>> selector = make_nearest_action_selector(result)
>>> config = ShieldConfig(threshold=0.9, post_selector=selector)
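A sketch of nearest-index replacement under the same explicit-state assumptions (per-state lists of masks, fallback masks and action counts). The names are hypothetical, and breaking distance ties toward the lower index is an assumption; this reference does not say how ties are resolved.

```python
from typing import Callable, List


def make_nearest_action_selector_sketch(
    masks: List[int], fallbacks: List[int], nr_actions: List[int]
) -> Callable[[int, int], int]:
    def select(state: int, agent_action: int) -> int:
        n = nr_actions[state]
        allowed = [i for i in range(n) if (masks[state] >> i) & 1]
        if not allowed:  # dangerous state: use the fallback bits instead
            allowed = [i for i in range(n) if (fallbacks[state] >> i) & 1]
        if not allowed:
            raise RuntimeError(f"state {state} is unsafe: no allowed or fallback action")
        # Smallest |i - agent_action|; ties go to the lower index (assumption).
        return min(allowed, key=lambda i: (abs(i - agent_action), i))

    return select


# Five velocity levels; in state 0 only the extremes 0 and 4 are allowed.
select = make_nearest_action_selector_sketch([0b10001, 0b00000], [0, 0b00100], [5, 5])
print(select(0, 3))  # 4
print(select(1, 4))  # 2: dangerous state, fallback action
```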