API Reference

tempestpy.shielding.ShieldFactory(model_path: str, *, property: Optional[str] = None, constants: Optional[Dict[str, Any]] = None, check_result: Optional[Any] = None, engine: str = 'sparse')

Entry point for computing probabilistic shields from PRISM models.

A ShieldFactory parses a PRISM MDP/SMG, runs model checking to obtain per-action choice values, and converts those values into ShieldResult bitmasks via build or compute.

Parameters:

  • model_path (str, required): Path to a .prism file.
  • property (Optional[str], default None): A PCTL/PRCTL property string, e.g. "Pmax=? [G safety]".
  • constants (Optional[Dict[str, Any]], default None): PRISM constant definitions, e.g. {"N": 3, "slippery": False}.
  • check_result (Optional[Any], default None): A pre-computed stormpy check result that already contains choice values. When provided, model checking is skipped entirely.
  • engine (str, default 'sparse'): Backend representation. "sparse" (default) uses a sparse explicit-state model; "dd" uses a symbolic BDD-based model.

Raises:

  • ValueError: If engine is not "sparse" or "dd".
  • TypeError: If the model type or the engine/model combination is unsupported.

Examples:

>>> from tempestpy.shielding import ShieldConfig, ShieldFactory
>>> factory = ShieldFactory("model.prism", property="Pmax=? [G safe]",
...                         constants={"N": 5})
>>> result = factory.build(ShieldConfig(threshold=0.9))

model: Any property

The stormpy model held by this factory.

prism_program: Any property

The parsed storm::prism::Program held by this factory.

prop: Optional[str] property writable

The active PCTL property string, or None if not set.

formula: Any property

The parsed stormpy formula, or None before a property is set.

optimality_type: Any property

The optimality direction (Minimize / Maximize) derived from the formula, or None if no formula has been set.

is_sparse: bool property

True when the factory operates in sparse-engine mode.

state_lookup: StateValuationLookup property

Lazy-initialised bidirectional map between state ids and PRISM valuations.

with_prop(prop: str) -> 'ShieldFactory'

Return a copy of this factory with a different property string.

build(config: ShieldConfig) -> Any

Run model checking (if needed) and compute the ShieldResult.

If model checking has not been performed yet, it is triggered here. Subsequent calls with the same property reuse the cached result.

Parameters:

  • config (ShieldConfig, required): Threshold, comparison mode, and optional post-selector / action lookup settings for the ShieldResult computation.

Returns:

  • ShieldResult: Per-state bitmasks encoding which actions are permitted.

Raises:

  • RuntimeError: If neither a formula nor a check result is available.

compute(config: ShieldConfig) -> Any

Compute the ShieldResult from an already-available check result.

Unlike build, this method does not run model checking. A check result must be present.

Parameters:

  • config (ShieldConfig, required): Threshold, comparison mode, and optional post-selector / action lookup settings.

Returns:

  • ShieldResult: Per-state bitmasks encoding which actions are permitted.

Raises:

  • RuntimeError: If no check result is available.

get_choice_values_for_state(sid: int) -> np.ndarray

Return the model-checked choice values for all actions in state sid.

get_state_id(values: dict) -> int

Look up the state id for a given PRISM variable valuation.

Parameters:

  • values (dict, required): Mapping of PRISM variable name to value, e.g. {"x": 2, "y": 0, "done": False}.

Returns:

  • int: The model state id corresponding to values.

Raises:

  • KeyError: If the valuation does not match any state.
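To make the lookup concrete, here is a minimal, hypothetical stand-in for the bidirectional state/valuation map behind state_lookup and get_state_id. It assumes valuations are plain dicts and freezes each one into a sorted tuple of (variable, value) pairs so it can serve as a dictionary key. SketchStateLookup and valuation_key are illustrative names, not part of tempestpy.

```python
from typing import Any, Dict, List, Tuple

# Hashable form of a PRISM valuation: sorted (variable, value) pairs.
ValuationKey = Tuple[Tuple[str, Any], ...]


def valuation_key(values: Dict[str, Any]) -> ValuationKey:
    """Freeze {"y": 0, "x": 2} into (("x", 2), ("y", 0))."""
    return tuple(sorted(values.items()))


class SketchStateLookup:
    """Hypothetical bidirectional map between state ids and valuations."""

    def __init__(self, valuations: List[Dict[str, Any]]) -> None:
        # valuations[sid] is the PRISM valuation of state sid.
        self._by_id = [dict(v) for v in valuations]
        self._by_key = {valuation_key(v): sid for sid, v in enumerate(valuations)}

    def get_state_id(self, values: Dict[str, Any]) -> int:
        key = valuation_key(values)
        if key not in self._by_key:
            raise KeyError(f"no state matches valuation {values!r}")
        return self._by_key[key]

    def get_valuation(self, sid: int) -> Dict[str, Any]:
        # Return a copy so callers cannot mutate the stored valuation.
        return dict(self._by_id[sid])


lookup = SketchStateLookup([{"x": 0, "done": False}, {"x": 1, "done": False}])
print(lookup.get_state_id({"done": False, "x": 1}))  # 1
```

Because the key is built from sorted items, the order of entries in the dict passed to get_state_id does not matter, while a missing or extra variable changes the key and raises KeyError, in line with the documented behaviour.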

dangerous_states(*, shield_result: ShieldResult) -> list[dict]

Return the PRISM valuations of all dangerous states in shield_result.

A state is dangerous when no action meets the shield threshold and the fallback (best-value) action is used instead.

critical_states(*, shield_result: ShieldResult) -> list[dict]

Return the PRISM valuations of all critical states in shield_result.

A state is critical when at least one action is blocked by the threshold but at least one action still passes.

safe_states(*, shield_result: ShieldResult) -> list[dict]

Return the PRISM valuations of all safe states in shield_result.

A state is safe when every action passes the shield threshold.

states_by_class(*, shield_result: ShieldResult) -> dict[str, list[dict]]

Partition all states into the four shield classes ("safe", "critical", "dangerous", "unsafe"), mapping each class name to the valuations of its states.
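The four classes can be derived from three per-state quantities: the threshold bitmask, the fallback bitmask, and the action count. The sketch below is one reading of the definitions given for dangerous_states, critical_states, safe_states and ShieldResult.classify_state. It assumes bitmasks are non-negative Python ints with bit i standing for action i. The function names are hypothetical, and for brevity the partition is keyed by state id rather than by valuation.

```python
from typing import Dict, List, Sequence


def classify(raw_bitmask: int, fallback_bitmask: int, nr_actions: int) -> str:
    """Classify one state from its threshold mask, fallback mask and action count."""
    if raw_bitmask == 0:
        # No action met the threshold: the fallback decides between the two.
        return "dangerous" if fallback_bitmask != 0 else "unsafe"
    all_actions = (1 << nr_actions) - 1
    if (raw_bitmask & all_actions) == all_actions:
        return "safe"  # every available action passes
    return "critical"  # some, but not all, actions pass


def partition(
    bitmask_by_state: Sequence[int],
    fallback_by_state: Sequence[int],
    nr_actions_by_state: Sequence[int],
) -> Dict[str, List[int]]:
    """Group state ids by class, in the spirit of states_by_class."""
    classes: Dict[str, List[int]] = {
        "safe": [], "critical": [], "dangerous": [], "unsafe": []
    }
    for sid, (raw, fb, n) in enumerate(
        zip(bitmask_by_state, fallback_by_state, nr_actions_by_state)
    ):
        classes[classify(raw, fb, n)].append(sid)
    return classes


print(partition([0b11, 0b01, 0, 0], [0, 0, 0b10, 0], [2, 2, 2, 2]))
# {'safe': [0], 'critical': [1], 'dangerous': [2], 'unsafe': [3]}
```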

save(path: str, *, slim: bool = True) -> None

Persist this factory to a pickle file.

Parameters:

  • path (str, required): Destination file path.
  • slim (bool, default True): If True, persist only the choice values and the state lookup, which produces the smallest file. If False, also embed the DRN-serialised Storm model and the PRISM path, so the factory is fully functional after load.

load(path: str) -> 'ShieldFactory' classmethod

Load a ShieldFactory from a pickle file produced by save.

Notes

If saved with slim=True: only choice values and state lookup are available; the Storm model is not restored. If saved with slim=False: the Storm model is restored from DRN; call build() to re-run model checking if needed.

to_bitmask(shield_result: Any) -> dict[int, Bitmask]

Convert shield_result to a {state_id: bitmask} dictionary.

to_allowed_action_indices(shield_result: Any) -> dict[int, list[int]]

Convert shield_result to a {state_id: [action_index, ...]} dictionary.

to_allowed_action_labels(shield_result: Any) -> dict[int, list]

Convert shield_result to a {state_id: [frozenset(labels), ...]} dictionary.

to_allowed_state_action_pairs(shield_result: Any) -> list[tuple[int, int]]

Return all allowed (state_id, local_action_index) pairs.
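As a rough model of how these converters relate, the following sketch assumes an explicit result whose per-state bitmasks are non-negative Python ints, with bit i set when local action i is allowed. It works on the raw bitmask_by_state values only; whether the library's converters substitute the fallback mask in dangerous states is not stated here. The helper names are hypothetical.

```python
from typing import Dict, List, Tuple


def bits_to_indices(bitmask: int) -> List[int]:
    """Indices of the set bits, lowest first; assumes bitmask >= 0."""
    return [i for i in range(bitmask.bit_length()) if (bitmask >> i) & 1]


def allowed_action_indices(bitmask_by_state: List[int]) -> Dict[int, List[int]]:
    """{state_id: [action_index, ...]}, as described for to_allowed_action_indices."""
    return {sid: bits_to_indices(m) for sid, m in enumerate(bitmask_by_state)}


def allowed_state_action_pairs(bitmask_by_state: List[int]) -> List[Tuple[int, int]]:
    """Flat list of (state_id, local_action_index) pairs."""
    return [
        (sid, a)
        for sid, m in enumerate(bitmask_by_state)
        for a in bits_to_indices(m)
    ]


masks = [0b101, 0b010, 0]
print(allowed_action_indices(masks))      # {0: [0, 2], 1: [1], 2: []}
print(allowed_state_action_pairs(masks))  # [(0, 0), (0, 2), (1, 1)]
```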

to_valuation_bitmask(shield_result: Any) -> dict[dict, Bitmask]

Convert shield_result to a {valuation_key: bitmask} dictionary.

to_valuation_allowed_action_labels(shield_result: Any) -> dict

Convert shield_result to a {valuation_key: [frozenset(labels), ...]} dictionary.

to_valuation_action_label_probability() -> dict

Map every state to all of its actions and their model-checked probabilities: {valuation_key: {frozenset(action_labels): probability}}.

pretty(shield_result: Any, *, max_states: int = 25, show_valuations: bool = True, action_ref: str = 'labels') -> str

Return a human-readable summary of shield_result for debugging.

Parameters:

  • shield_result (Any, required): A ShieldResult.
  • max_states (int, default 25): Maximum number of states to include in the output.
  • show_valuations (bool, default True): If True, show PRISM variable values next to each state.
  • action_ref (str, default 'labels'): How to identify actions: "labels" shows frozenset(labels); "index" shows local action indices.

Returns:

  • str: Multi-line human-readable shield summary.

to_storm_format(shield_result: Any) -> str

Render shield_result as a Storm-compatible shield text.

Parameters:

  • shield_result (Any, required): A ShieldResult.

Returns:

  • str: Multi-line string matching Storm's shield output format, with per-action choice values and valuation annotations.

tempestpy.shielding.ShieldConfig(threshold: float = 1.0, comparison: ComparisonMode = 'absolute', post_selector: Optional[PostSelectorFn] = None, action_lookup: Optional[ActionLookup] = None) dataclass

Immutable configuration for a single shield computation.

Attributes:

  • threshold (float): The probability threshold. In "absolute" mode an action is allowed when its choice value is >= threshold (maximise) or <= threshold (minimise). In "relative" mode the threshold is multiplied by the per-state best value before comparison. Defaults to 1.0: in "relative" mode this admits only actions that attain the per-state best value, while in "absolute" mode under maximisation it admits only actions whose value is at least 1.0.
  • comparison (ComparisonMode): How to interpret threshold: "absolute" compares directly against the model-checked probability; "relative" compares against threshold * best_value for the current state.
  • post_selector (Optional[PostSelectorFn]): Optional callable invoked by ShieldResult.query_post when the agent's chosen action is blocked.
  • action_lookup (Optional[ActionLookup]): Optional ActionLookup that maps PRISM action labels to global RL action indices. When set, bitmask bits correspond to global indices rather than local model action indices.
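The threshold rule above can be sketched for a single state as follows. This is a hypothetical helper, not the library's implementation: it assumes choice values arrive as a sequence indexed by local action, treats "relative" mode under minimisation as value <= threshold * best, and ignores any floating-point tolerance the real code may apply.

```python
from typing import Sequence


def threshold_bitmask(
    choice_values: Sequence[float],
    threshold: float = 1.0,
    comparison: str = "absolute",
    maximise: bool = True,
) -> int:
    """Set bit i when action i passes the threshold for one state."""
    if comparison not in ("absolute", "relative"):
        raise ValueError(f"unknown comparison mode: {comparison!r}")
    if len(choice_values) == 0:
        return 0
    if comparison == "relative":
        # Scale the threshold by the best value available in this state.
        best = max(choice_values) if maximise else min(choice_values)
        bound = threshold * best
    else:
        bound = threshold
    mask = 0
    for i, value in enumerate(choice_values):
        passes = value >= bound if maximise else value <= bound
        if passes:
            mask |= 1 << i
    return mask


values = [0.95, 0.80, 0.99]
print(threshold_bitmask(values, 0.9))              # 5 (actions 0 and 2)
print(threshold_bitmask(values, 1.0, "relative"))  # 4 (only the best, action 2)
print(threshold_bitmask(values, 1.0))              # 0 (no action reaches 1.0)
```

The last two calls show why the default of 1.0 behaves differently in the two modes.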

tempestpy.shielding.ShieldResult(config: Any, bit_width: Optional[int] = None, coalition_states: Optional[frozenset] = None) dataclass

Bases: ABC

Base class for shield results produced by explicit and symbolic backends.

Explicit shields are keyed by integer state ids. Symbolic shields are keyed by PRISM variable valuations, for example {"x": 4, "y": 7}.

query_mask(state: Any) -> int

Return the effective pre-shielding bitmask for state.

Bit i is set when action i is permitted. If no action meets the threshold (raw bitmask is 0), the fallback bitmask is returned instead, so the result is always non-zero unless the state is truly unsafe.
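The fallback rule in query_mask amounts to a single conditional. The sketch assumes both masks are available as ints for the queried state (as with ExplicitShieldResult.bitmask_by_state and fallback_by_state); effective_mask and is_set are hypothetical helper names.

```python
def effective_mask(raw_bitmask: int, fallback_bitmask: int) -> int:
    """Threshold mask if any action passed, otherwise the fallback mask."""
    return raw_bitmask if raw_bitmask != 0 else fallback_bitmask


def is_set(mask: int, action_index: int) -> bool:
    """True when bit action_index is set in mask."""
    return ((mask >> action_index) & 1) == 1


mask = effective_mask(0, 0b0100)  # dangerous state: the fallback is used
print(mask, is_set(mask, 2), is_set(mask, 0))  # 4 True False
```

The result is zero only when both inputs are zero, which is exactly the "unsafe" case.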

query_post(state: Any, agent_action: int, **kwargs: Any) -> int

Return agent_action if it is allowed by the shield, otherwise delegate to ShieldConfig.post_selector.

Parameters:

  • state (Any, required): Current state key (int for explicit, dict for symbolic).
  • agent_action (int, required): The action index the agent wants to take.
  • **kwargs (Any, default {}): Forwarded verbatim to the post-selector (e.g. q_values).

Returns:

  • int: The action index to actually execute.

Raises:

  • TypeError: If the action is blocked and no post_selector is configured.
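The pass-through-or-delegate behaviour of query_post can be modelled as below. This is a hypothetical sketch: it takes the state's mask directly (whether the library consults the raw or the effective mask is an assumption here) and shows how keyword arguments such as q_values reach the selector. query_post_sketch, pick_highest_q and ALLOWED are illustrative names.

```python
from typing import Any, Callable, List, Optional


def query_post_sketch(
    mask: int,
    state: Any,
    agent_action: int,
    post_selector: Optional[Callable[..., int]] = None,
    **kwargs: Any,
) -> int:
    """Pass agent_action through if allowed, else ask post_selector for a replacement."""
    if (mask >> agent_action) & 1:
        return agent_action  # allowed: unchanged
    if post_selector is None:
        raise TypeError("action is blocked and no post_selector is configured")
    # Blocked: delegate, forwarding extra keyword arguments such as q_values.
    return post_selector(state, agent_action, **kwargs)


ALLOWED = 0b011  # hypothetical mask shared with the toy selector


def pick_highest_q(state: Any, blocked: int, q_values: List[float]) -> int:
    """Toy selector: the allowed action with the highest Q-value."""
    candidates = [a for a in range(len(q_values)) if (ALLOWED >> a) & 1]
    return max(candidates, key=lambda a: q_values[a])


print(query_post_sketch(ALLOWED, 0, 1))  # 1 (allowed, passed through)
print(query_post_sketch(ALLOWED, 0, 2, pick_highest_q, q_values=[0.1, 0.7, 0.9]))  # 1
```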

query_action(state: Any, agent_action: int, **kwargs: Any) -> int

Alias for query_post for callers that prefer action-oriented naming.

is_action_allowed(state: Any, action_index: int) -> bool

Return True when action_index passes the shield threshold at state.

is_coalition_state(sid: int) -> bool

Return True when sid belongs to the shielded coalition.

For plain MDPs (no SMG), always returns True. For SMGs, returns True only for ego-player states; opponent states are never blocked.

is_unsafe(state: Any) -> bool

Return True when no action is available and no fallback exists.

Both the threshold bitmask and the fallback bitmask are zero, meaning the shield has no safe or even best-effort action to offer.

is_dangerous(state: Any) -> bool

Return True when no action met the threshold and the fallback is active.

The fallback bitmask is non-zero, so the shield can still suggest an action, but none of them satisfied the safety threshold.

is_critical(state: Any) -> bool

Return True when at least one action is blocked but at least one passes.

The threshold bitmask is non-zero and fewer than all actions are permitted, with no fallback active.

is_safe(state: Any) -> bool

Return True when every available action passes the shield threshold.

The threshold bitmask covers all actions and no fallback is active.

classify_state(state: Any) -> str

Classify state into one of four shield categories.

Returns:

  • str: "safe" if all actions are allowed; "critical" if some actions are blocked but at least one is allowed; "dangerous" if no action meets the threshold and the fallback is used; "unsafe" if no action is available and there is no fallback.

tempestpy.shielding.ExplicitShieldResult(config: Any, bit_width: Optional[int] = None, coalition_states: Optional[frozenset] = None, bitmask_by_state: list[Bitmask] = list(), fallback_by_state: list[Bitmask] = list(), nr_actions_by_state: list[int] = list(), best_value_by_state: Optional[np.ndarray] = None) dataclass

Bases: ShieldResult

Shield result produced by the sparse explicit-state backend (engine="sparse").

States are indexed by integer state id. All query methods accept an int and index into the per-state lists directly.

Attributes:

  • bitmask_by_state (list[Bitmask]): Per-state allowed-action bitmask. Bit i is set when action i meets the shield threshold.
  • fallback_by_state (list[Bitmask]): Per-state fallback bitmask used when no action meets the threshold (the best available action, or 0 when none exists).
  • nr_actions_by_state (list[int]): Total number of actions available in each state (used for is_safe / is_critical classification).
  • best_value_by_state (Optional[ndarray]): Optional array of per-state best model-checked values, filled when the factory computes best-value selectors.

tempestpy.shielding.SymbolicShieldResult(config: Any, bit_width: Optional[int] = None, coalition_states: Optional[frozenset] = None, bitmask_add: Any = None, fallback_add: Any = None, nr_actions: int = 0) dataclass

Bases: ShieldResult

Shield result produced by the DD-backed symbolic backend (engine="dd").

States are identified by PRISM variable valuations rather than integer ids, e.g. {"x": 4, "y": 7}. All query methods accept a dict and evaluate the underlying algebraic decision diagrams (ADDs) at that point.

Attributes:

  • bitmask_add (Add): ADD encoding the allowed-action Bitmask over the full state space. It replaces the list[Bitmask] of ExplicitShieldResult and is queried via bitmask_add.query(valuation).
  • fallback_add (Optional[Add]): ADD encoding the fallback Bitmask, or None when no fallback was computed (treated as 0 at every state).
  • nr_actions (int): Uniform action count for the model, used by is_safe / is_critical. When 0, the bit-length of the bitmask is used as a conservative estimate.


Gymnasium Wrappers

tempestpy.shielding.PreShieldWrapper(env: gym.Env, factory: ShieldFactory, config: ShieldConfig, *, obs_to_values: ObsToValuesFn)

Bases: Wrapper

Gymnasium wrapper that adds pre-shield action masking.

On every reset() and step(), the wrapper queries the shield and injects an "action_mask" boolean array into info. The mask can be consumed directly or via the action_masks() hook for SB3 MaskablePPO.

The info dict after reset() / step() contains:

  • "action_mask": np.ndarray[bool] of shape (n_actions,)
  • "shield_state_id": int, the state id used for the current mask
  • "shield_bitmask": int, the raw packed bitmask

See Also

  • PostShieldWrapper: Corrects the agent's chosen action after the fact.
  • ShieldFactory.build: Produces the ShieldResult used internally.
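The relationship between "shield_bitmask" and "action_mask" can be sketched as a bit-unpacking step. This assumes bit i of the bitmask corresponds to index i of the Discrete action space (for example because an ActionLookup was configured); bitmask_to_action_mask is a hypothetical helper, not a documented function.

```python
import numpy as np


def bitmask_to_action_mask(bitmask: int, n_actions: int) -> np.ndarray:
    """Unpack a shield bitmask into a bool array of shape (n_actions,)."""
    return np.array(
        [((bitmask >> i) & 1) == 1 for i in range(n_actions)], dtype=bool
    )


info = {"shield_bitmask": 0b0101}
info["action_mask"] = bitmask_to_action_mask(info["shield_bitmask"], 4)
print(info["action_mask"].tolist())  # [True, False, True, False]
```

An array in this shape is what an action-masking learner such as MaskablePPO expects from action_masks().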

Initialise a PreShieldWrapper.

Parameters:

  • env (Env, required): The Gymnasium environment to wrap. Must have a gym.spaces.Discrete action space.
  • factory (ShieldFactory, required): A ShieldFactory with a built model and property.
  • config (ShieldConfig, required): ShieldConfig specifying the threshold and comparison mode. Its post_selector is not used here.
  • obs_to_values (ObsToValuesFn, required): Callable (obs, info) -> {prism_var: value} that maps an observation (and optionally info) to the PRISM state variables needed to look up the shield state id.

Raises:

  • TypeError: If env.action_space is not gym.spaces.Discrete.
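An obs_to_values callable is usually a few lines of glue code. The sketch below assumes a hypothetical grid world whose observation is a two-element array [x, y] and whose PRISM model declares integer variables x and y; both the layout and the variable names are assumptions. Casting NumPy scalars to plain Python ints is a cheap precaution so the valuation matches the types the state lookup was built with.

```python
from typing import Any, Dict

import numpy as np


def obs_to_values(obs: np.ndarray, info: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical mapping for a grid world whose PRISM model declares x and y."""
    # info is unused here but is part of the expected (obs, info) signature.
    return {"x": int(obs[0]), "y": int(obs[1])}


print(obs_to_values(np.array([2, 0]), {}))  # {'x': 2, 'y': 0}
```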

action_masks() -> np.ndarray

Return the current boolean action mask.

Returns:

  • ndarray: Boolean array of shape (n_actions,).

Raises:

  • RuntimeError: If called before the first reset().

rebuild(config: ShieldConfig) -> None

Recompute the shield with a new ShieldConfig, which allows the threshold to be changed mid-training.

Parameters:

  • config (ShieldConfig, required): New shield configuration.

tempestpy.shielding.PostShieldWrapper(env: gym.Env, factory: ShieldFactory, config: ShieldConfig, *, obs_to_values: ObsToValuesFn)

Bases: Wrapper

Gymnasium wrapper that enforces shield safety via post-shielding.

The agent may select any action; unlike PreShieldWrapper, no action mask is provided. Before the action reaches env.step, the wrapper checks the shield: if the action is allowed, it passes through unchanged; if it is blocked, ShieldConfig.post_selector is called to supply a safe replacement.

The info dict after reset() / step() contains:

  • "shield_state_id": int, the state id used for the next step
  • "shield_safe_action": int, the action actually passed to env.step
  • "shield_corrected": bool, True when the agent's action was replaced

See Also

  • PreShieldWrapper: Injects an action mask before the agent acts.
  • make_best_value_selector: A ready-made post-selector for this wrapper.

Initialise a PostShieldWrapper.

Parameters:

  • env (Env, required): The Gymnasium environment to wrap. Must have a gym.spaces.Discrete action space.
  • factory (ShieldFactory, required): A ShieldFactory with a built model and property.
  • config (ShieldConfig, required): ShieldConfig with a non-None post_selector.
  • obs_to_values (ObsToValuesFn, required): Callable (obs, info) -> {prism_var: value} that maps an observation to the PRISM state variables needed for shield lookup.

Raises:

  • TypeError: If env.action_space is not gym.spaces.Discrete.
  • ValueError: If config.post_selector is None.

rebuild(config: ShieldConfig) -> None

Recompute the shield with a new ShieldConfig, which allows the threshold to be changed mid-training.

Parameters:

  • config (ShieldConfig, required): New configuration. Must have a non-None post_selector.

Raises:

  • ValueError: If config.post_selector is None.

tempestpy.shielding.make_best_value_selector(factory: ShieldFactory, result: ShieldResult, *, is_minimize: bool = False) -> Callable[[int, int], int]

Create a post-selector that replaces a blocked action with the best-value action.

The returned callable closes over factory to retrieve per-state choice values at call time. Works in both normal and dangerous states; in dangerous states it falls back to the argmax/argmin action.

Parameters:

  • factory (ShieldFactory, required): The ShieldFactory used to look up choice values.
  • result (ShieldResult, required): The ShieldResult (currently unused; kept for a consistent selector API).
  • is_minimize (bool, default False): If True, select the action with the lowest choice value; otherwise select the highest (maximise).

Returns:

  • Callable[[int, int], int]: selector(state_id, blocked_action) -> replacement_action.

Examples:

>>> selector = make_best_value_selector(factory, result, is_minimize=False)
>>> config = ShieldConfig(threshold=0.9, post_selector=selector)
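The closure pattern described above can be sketched without tempestpy. This hypothetical stand-in takes a plain callable in place of the factory (a bound method such as factory.get_choice_values_for_state would fit that role) and picks the argmax/argmin over all actions of the state, ignoring blocked_action; that is one reading of the description, not a copy of the library's code.

```python
from typing import Callable, Sequence


def make_best_value_selector_sketch(
    choice_values_for_state: Callable[[int], Sequence[float]],
    *,
    is_minimize: bool = False,
) -> Callable[[int, int], int]:
    """Return selector(state_id, blocked_action) -> best-value action index."""

    def selector(state_id: int, blocked_action: int) -> int:
        # Values are fetched at call time rather than captured up front.
        values = list(choice_values_for_state(state_id))
        pick = min if is_minimize else max
        # Ties resolve to the lowest index because min/max keep the first hit.
        return pick(range(len(values)), key=lambda a: values[a])

    return selector


table = {0: [0.2, 0.9, 0.5]}
best = make_best_value_selector_sketch(table.__getitem__)
print(best(0, 2))  # 1
```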

tempestpy.shielding.make_random_safe_selector(result: ShieldResult, *, rng: Optional[np.random.Generator] = None) -> Callable[[int, int], int]

Create a post-selector that replaces a blocked action with a random allowed action.

In dangerous states (no action meets the threshold) the fallback action from ExplicitShieldResult.fallback_by_state is used instead.

Parameters:

  • result (ShieldResult, required): The ShieldResult used to look up allowed actions.
  • rng (Optional[Generator], default None): Optional NumPy random generator. If None, a fresh default generator is created.

Returns:

  • Callable[[int, int], int]: selector(state_id, blocked_action) -> replacement_action.

Examples:

>>> selector = make_random_safe_selector(result, rng=np.random.default_rng(42))
>>> config = ShieldConfig(threshold=0.9, post_selector=selector)
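A minimal sketch of the random-safe strategy, assuming the explicit per-state lists (bitmask_by_state, fallback_by_state) are passed in directly instead of a ShieldResult. The function name is hypothetical. If a fallback mask has several bits set, this sketch picks among them at random, which may differ from the library.

```python
from typing import Callable, List, Optional

import numpy as np


def make_random_safe_selector_sketch(
    bitmask_by_state: List[int],
    fallback_by_state: List[int],
    *,
    rng: Optional[np.random.Generator] = None,
) -> Callable[[int, int], int]:
    """Return selector(state_id, blocked_action) -> random allowed action."""
    rng = rng if rng is not None else np.random.default_rng()

    def selector(state_id: int, blocked_action: int) -> int:
        mask = bitmask_by_state[state_id]
        if mask == 0:
            mask = fallback_by_state[state_id]  # dangerous state: use the fallback
        allowed = [i for i in range(mask.bit_length()) if (mask >> i) & 1]
        if not allowed:
            raise RuntimeError(f"state {state_id} has no allowed or fallback action")
        return int(rng.choice(allowed))

    return selector


select = make_random_safe_selector_sketch(
    [0b1010, 0], [0, 0b0100], rng=np.random.default_rng(42)
)
print(select(0, 0) in (1, 3), select(1, 0))  # True 2
```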

tempestpy.shielding.make_nearest_action_selector(result: ShieldResult) -> Callable[[int, int], int]

Create a post-selector that replaces a blocked action with the nearest allowed action.

Nearest is measured by absolute index distance, which makes this useful when action indices have a natural ordering (e.g. discretised velocity). In dangerous states it falls back to ExplicitShieldResult.fallback_by_state.

Parameters:

  • result (ShieldResult, required): The ShieldResult used to look up allowed actions.

Returns:

  • Callable[[int, int], int]: selector(state_id, blocked_action) -> replacement_action.

Examples:

>>> selector = make_nearest_action_selector(result)
>>> config = ShieldConfig(threshold=0.9, post_selector=selector)
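The nearest-index rule can be sketched the same way, again assuming the explicit per-state lists are passed in directly. The function name is hypothetical, and the tie-breaking rule (equal distances go to the lower index) is an assumption, since the description above does not specify one.

```python
from typing import Callable, List


def make_nearest_action_selector_sketch(
    bitmask_by_state: List[int], fallback_by_state: List[int]
) -> Callable[[int, int], int]:
    """Return selector(state_id, blocked_action) -> closest allowed index."""

    def selector(state_id: int, blocked_action: int) -> int:
        # Use the fallback mask when no action met the threshold.
        mask = bitmask_by_state[state_id] or fallback_by_state[state_id]
        allowed = [i for i in range(mask.bit_length()) if (mask >> i) & 1]
        if not allowed:
            raise RuntimeError(f"state {state_id} has no allowed or fallback action")
        # Smallest |i - blocked_action|; equal distances go to the lower index.
        return min(allowed, key=lambda i: (abs(i - blocked_action), i))

    return selector


select = make_nearest_action_selector_sketch([0b10001, 0], [0, 0b1000])
print(select(0, 3), select(0, 2), select(1, 0))  # 4 0 3
```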