
ShieldFactory & ShieldConfig

A shield introduces per-state constraints: for every state in an MDP or SMG, it records which actions are safe enough to take, expressed as a bitmask. A ShieldFactory builds the model and computes, for each state-action pair, the probability of staying safe. It does so by running probabilistic model checking on a PRISM model. To build a concrete shield, specify the safety threshold via a ShieldConfig.
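
To make the bitmask representation concrete, here is a minimal sketch in plain Python. The helper names encode_mask and decode_mask are hypothetical and not part of tempestpy; they only illustrate the convention that bit i being set means action i is allowed.

```python
def encode_mask(allowed_actions):
    """Pack an iterable of allowed action indices into an int bitmask."""
    mask = 0
    for i in allowed_actions:
        mask |= 1 << i  # set bit i for allowed action i
    return mask


def decode_mask(mask):
    """Return the sorted list of action indices whose bit is set."""
    return [i for i in range(mask.bit_length()) if (mask >> i) & 1]


mask = encode_mask([0, 2])   # actions 0 and 2 are safe, action 1 is blocked
print(bin(mask))             # 0b101
print(decode_mask(mask))     # [0, 2]
```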


Quick start

from tempestpy.shielding import ShieldFactory, ShieldConfig

factory = ShieldFactory(
    "model.prism",
    property="Pmax=? [G safe]",
    constants={"N": 5, "slippery": False},
)
config  = ShieldConfig(threshold=0.9)
result  = factory.build(config)

factory.build(config) runs model checking once, caches the raw choice values, and returns an ExplicitShieldResult (or SymbolicShieldResult for engine="dd").
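
The point of caching is that the expensive step (model checking) runs once, while thresholding is cheap and can be repeated. The sketch below illustrates that split with a hypothetical stand-in class, CachingFactory, and a fake solver; it is not the tempestpy implementation, and its threshold rule (value ≥ threshold) is a simplifying assumption.

```python
class CachingFactory:
    """Hypothetical stand-in: build() runs the solver, compute() reuses its output."""

    def __init__(self, solver):
        self._solver = solver    # expensive callable returning {state: [values]}
        self._values = None      # cached raw choice values
        self.solver_calls = 0

    def build(self, threshold):
        # Run the expensive step, cache the raw values, then apply the threshold.
        self._values = self._solver()
        self.solver_calls += 1
        return self.compute(threshold)

    def compute(self, threshold):
        # Cheap step: only re-applies the threshold to the cached values.
        if self._values is None:
            raise RuntimeError("no cached values; call build() first")
        return {
            state: [v >= threshold for v in vals]
            for state, vals in self._values.items()
        }


factory = CachingFactory(lambda: {0: [0.95, 0.50]})
strict = factory.build(0.9)      # runs the fake solver
lenient = factory.compute(0.4)   # reuses the cached values
print(strict, lenient, factory.solver_calls)
```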


ShieldFactory

Constructor

Parameter     Type             Description
model_path    str              Path to a .prism file.
property      str              PCTL/PRCTL property, e.g. "Pmax=? [G safe]".
constants     dict             PRISM constants, e.g. {"N": 5, "slippery": False}.
check_result  stormpy result   Pre-computed check result; skips model checking.
engine        "sparse" | "dd"  Backend. "sparse" (default) for explicit-state; "dd" for symbolic.

Using a Different Property

To use a different property, set prop on the existing factory:

# Switch to a different safety objective without rebuilding the model:
factory.prop = "Pmin=? [F crash]"
result2 = factory.build(ShieldConfig(threshold=0.05, comparison="absolute"))

or use with_prop to get a copy of the factory with the new property:

factory2 = factory.with_prop("Pmin=? [F crash]")

Querying state information

# Map PRISM variable values to a state id
sid = factory.get_state_id({"x": 3, "y": 1, "done": False})

# Get the raw model-checked choice values for that state
pvals = factory.get_choice_values_for_state(sid)  # np.ndarray, one entry per action

Saving and loading

# Slim save: only choice values and state lookup (fastest, smallest file)
factory.save("shield.pkl")

# Full save: also embeds the DRN-serialised model (factory is fully self-contained)
factory.save("shield_full.pkl", slim=False)

# Load back
factory2 = ShieldFactory.load("shield.pkl")
result2 = factory2.compute(config)   # reuse cached values; no model checking needed

Slim vs full

slim=True (default) saves only the pre-computed choice values and the state lookup table. A factory loaded from a slim file can call compute() but not build(), because the model is absent. Use slim=False when the loaded factory must be able to re-run model checking.


ShieldConfig

A ShieldConfig specifies how a concrete shield should be built using a ShieldFactory.

from tempestpy.shielding import ShieldConfig

config = ShieldConfig(
    threshold=0.9,         # default: 1.0
    comparison="relative", # default: "absolute"
)

Threshold and comparison mode

absolute (default): action i is allowed when its model-checked value ≥ threshold (for Pmax) or ≤ threshold (for Pmin).

# Allow any action achieving at least 80 % safety probability
config = ShieldConfig(threshold=0.8, comparison="absolute")
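
As a rough illustration of the absolute rule, the following sketch turns a list of per-action values into a bitmask. The function absolute_mask and its minimize flag are hypothetical names chosen for this example; they are not part of the tempestpy API.

```python
def absolute_mask(values, threshold, minimize=False):
    """Bit i is set when action i passes the absolute threshold test.

    minimize=False models a Pmax property (value >= threshold),
    minimize=True models a Pmin property (value <= threshold).
    """
    mask = 0
    for i, value in enumerate(values):
        passes = value <= threshold if minimize else value >= threshold
        if passes:
            mask |= 1 << i
    return mask


# Pmax-style: keep actions with at least 80 % safety probability.
print(bin(absolute_mask([0.95, 0.80, 0.40], 0.8)))                 # 0b11
# Pmin-style: keep actions with at most 5 % crash probability.
print(bin(absolute_mask([0.01, 0.20], 0.05, minimize=True)))       # 0b1
```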

relative: threshold is scaled by the per-state best value before comparison. An action is allowed when its value ≥ threshold × best_value_in_state.

# Allow actions within 5 % of the best achievable value in each state
config = ShieldConfig(threshold=0.95, comparison="relative")
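
The relative rule can be sketched the same way. The example below assumes a maximising property, so that the per-state best value is the maximum; relative_mask is a hypothetical helper, not a tempestpy function.

```python
def relative_mask(values, threshold):
    """Bit i is set when values[i] >= threshold * max(values)."""
    cutoff = threshold * max(values)  # scale the threshold by the per-state best
    mask = 0
    for i, value in enumerate(values):
        if value >= cutoff:
            mask |= 1 << i
    return mask


# Best value is 0.60, so the cutoff is roughly 0.57.
values = [0.60, 0.58, 0.30]
print(bin(relative_mask(values, 0.95)))  # 0b11
```

Note that in this state no action reaches 0.95 in absolute terms, yet two actions remain allowed because they are close to the best value achievable there.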

post_selector

A post_selector is only needed for post-shielding. It is a callable (state_id, blocked_action) -> replacement_action invoked when the agent picks a blocked action. See the Wrappers page for ready-made selectors.
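
For illustration only, here is a sketch of a selector with that call signature, together with a small function that mimics the post-shielding decision. Both make_first_allowed_selector and post_shield are hypothetical names; the masks are hard-coded, and returning the blocked action when nothing is allowed is an assumption of this sketch rather than documented behaviour.

```python
def make_first_allowed_selector(masks):
    """Build a selector that replaces a blocked action with the lowest allowed one."""
    def selector(state_id, blocked_action):
        mask = masks[state_id]
        if mask == 0:
            return blocked_action  # assumption: nothing is allowed, keep the agent's choice
        return (mask & -mask).bit_length() - 1  # index of the lowest set bit
    return selector


def post_shield(masks, selector, state_id, action):
    """Return the action unchanged if allowed, otherwise ask the selector."""
    if (masks[state_id] >> action) & 1:
        return action
    return selector(state_id, action)


masks = {7: 0b0110}  # in state 7, actions 1 and 2 are allowed
selector = make_first_allowed_selector(masks)
print(post_shield(masks, selector, 7, 2))  # 2 (allowed, passed through)
print(post_shield(masks, selector, 7, 0))  # 1 (blocked, replaced)
```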

action_lookup

By default, bit i of the bitmask corresponds to local action index i in the PRISM model. When the RL environment uses a different global action numbering, supply an ActionLookup:

from tempestpy.shielding.action_lookup import ActionLookup

lookup = ActionLookup({"left": 0, "right": 1, "up": 2, "down": 3})
config = ShieldConfig(threshold=0.9, action_lookup=lookup)

Bit i in the resulting bitmask then corresponds to global RL action index i.
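
The remapping can be pictured as moving each set bit from its local position to its global position. The sketch below uses a plain dict in place of ActionLookup and assumes the lookup keys are the PRISM action labels; remap_mask and local_labels are hypothetical names for this example.

```python
def remap_mask(local_mask, local_labels, lookup):
    """Translate a mask over local action indices into one over global indices.

    local_labels[i] is the label of local action i in the current state;
    lookup maps each label to its global RL action index.
    """
    global_mask = 0
    for i, label in enumerate(local_labels):
        if (local_mask >> i) & 1:
            global_mask |= 1 << lookup[label]
    return global_mask


lookup = {"left": 0, "right": 1, "up": 2, "down": 3}
local_labels = ["up", "left"]  # in this state, local action 0 is "up", 1 is "left"

print(bin(remap_mask(0b01, local_labels, lookup)))  # 0b100 (only "up" allowed)
print(bin(remap_mask(0b11, local_labels, lookup)))  # 0b101 ("up" and "left" allowed)
```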


ShieldResult

build() and compute() return a ShieldResult subclass. For the sparse engine that is ExplicitShieldResult (states are integer ids); for engine="dd" it is SymbolicShieldResult (states are PRISM variable dicts).

Querying the mask

# Pre-shielding: get the bitmask for a state (falls back when nothing passes)
bits = result.query_mask(sid)                        # int; bit i = action i is allowed

# Post-shielding: pass agent's chosen action through the shield
safe = result.query_post(sid, agent_action)          # returns agent_action or selector result

# Check a single action
allowed = result.is_action_allowed(sid, action_index)  # bool

In most cases you won't need to query the shield manually; the Wrappers handle this for you automatically.