# scisbi/base/inference_algorithm.pyimportabcfromtypingimportAny,Dict,Optional# Assume these base classes are defined elsewhere in your libraryfrom.simulatorimportBaseSimulatorfrom.summary_statisticimportBaseSummaryStatistic# New optional dependency
class BaseInferenceAlgorithm(abc.ABC):
    """
    Common abstract parent for Simulation-Based Inference (SBI) algorithms.

    A concrete algorithm subclasses this type and implements :meth:`infer`.
    The base class only validates and keeps hold of the pieces every
    algorithm shares: the simulator, the prior, an optional summary
    statistic, and a dictionary of algorithm-specific settings.
    """

    def __init__(
        self,
        simulator: BaseSimulator,
        prior: Any,
        summary_statistic: Optional[BaseSummaryStatistic] = None,
        **kwargs: Any,
    ):
        """
        Validate and store the components used by the algorithm.

        Args:
            simulator (BaseSimulator): Simulator object; it must expose a
                callable ``simulate`` attribute.
            prior (Any): Prior distribution object; it must expose callable
                ``log_prob`` and ``sample`` attributes.
            summary_statistic (Optional[BaseSummaryStatistic]): Optional
                summary statistic intended to reduce data dimensionality.
                When given, it must expose a callable ``compute`` attribute.
                Defaults to None.
            **kwargs: Algorithm-specific configuration, stored unchanged in
                ``self.settings``. Typical examples are network
                hyperparameters, training options (epochs, batch size), the
                number of rounds for sequential methods, or acceptance
                thresholds for ABC.

        Raises:
            TypeError: If a required method is missing or not callable.
                Checks run in the order simulator ``simulate``, prior
                ``log_prob``, prior ``sample``, then summary statistic
                ``compute``; only the first failure is reported.
        """
        # Duck-typed requirements, checked in order; the first failure raises.
        required_methods = (
            (
                simulator,
                "simulate",
                "simulator must be an instance of a class with a 'simulate' method",
            ),
            (
                prior,
                "log_prob",
                "prior must be an instance of a class with a 'log_prob' method",
            ),
            (
                prior,
                "sample",
                "prior must be an instance of a class with a 'sample' method",
            ),
        )
        for component, method_name, error_message in required_methods:
            if not (
                hasattr(component, method_name)
                and callable(getattr(component, method_name))
            ):
                raise TypeError(error_message)

        # The summary statistic is optional, so it is validated only when given.
        if summary_statistic is not None:
            has_compute = hasattr(summary_statistic, "compute") and callable(
                summary_statistic.compute
            )
            if not has_compute:
                raise TypeError(
                    "summary_statistic, if provided, must be an instance of a class with a 'compute' method"
                )

        self.simulator: BaseSimulator = simulator
        self.prior: Any = prior
        self.summary_statistic: Optional[BaseSummaryStatistic] = summary_statistic
        # Any extra keyword arguments become the algorithm's configuration.
        self.settings: Dict[str, Any] = kwargs

    @abc.abstractmethod
    def infer(
        self,
        observed_data: Any,
        num_simulations: int,
        **kwargs: Any,
    ) -> Any:
        """
        Run the simulation-based inference procedure.

        Subclasses orchestrate simulation, any training, and posterior
        estimation here; the details differ widely between SBI families
        (e.g. ABC, SNPE, SNL, SRE).

        Args:
            observed_data (Any): Observation(s) to condition on. When a
                summary statistic was supplied at construction, this may be
                raw data that the implementation reduces with it; otherwise
                its format follows the simulator's output.
            num_simulations (int): Total simulation budget for the run,
                across all rounds when the method is sequential.
            **kwargs: Per-run, algorithm-specific options (for example
                epochs per round, a random seed, verbosity). Implementations
                are expected to let these take precedence over
                ``self.settings``.

        Returns:
            Any: An object representing the estimated posterior; it should
            offer query methods such as ``sample`` and, when a density is
            available, ``log_prob``.

        Raises:
            NotImplementedError: Whenever this base implementation is reached
                (e.g. through ``super().infer``).
            Exception: Implementations may raise any error specific to their
                own execution.
        """
        # A typical implementation:
        #   1. reads the observed data, applying the summary statistic if any;
        #   2. draws parameters from the prior (or an updated proposal, for
        #      sequential methods);
        #   3. simulates data for those parameters;
        #   4. summarises the simulated data when a summary statistic is set;
        #   5. fits a method-specific model (density estimator, classifier, ...)
        #      on the (parameter, data/summary) pairs;
        #   6. uses that model and the observation to construct or sample the
        #      posterior;
        #   7. returns an object representing the posterior.
        raise NotImplementedError(
            "Not implemented inference. Call originates from abstract base method"
        )