# Source code for lib.utils.classes

import numpy as np
import xml.etree.ElementTree as ET
from lib.utils.xmlread import XMLReader
from pathlib import Path
from functools import partial
import jax
import jax.numpy as jnp
# Base class for any problem defined


class ProblemObjectBase:
    """Abstract base class for optimization problems in the OSParamFitting framework.

    This class provides a common interface for all problem types, including
    ODE systems, parameter estimation problems, and other optimization tasks.
    Every public method is a thin wrapper that forwards to an underscore-
    prefixed hook (``_integrate_system``, ``_compute_loss``,
    ``_compute_all_losses``, ``_plot_result``); concrete implementations are
    expected to supply those hooks with the problem-specific logic.

    The base class only declares the common attributes (parameter names,
    initial conditions, time evaluation points, dataset information). They
    are meant to be populated by concrete implementations or by external
    configuration.

    Attributes:
        params_to_fit_names (dict): Dictionary mapping parameter names to
            their indices. Starts out as an empty dict.
        y0 (numpy.ndarray, optional): Initial conditions for unsteady problems.
        t_eval (numpy.ndarray, optional): Time points for solution evaluation.
        dataset (numpy.ndarray, optional): Experimental or reference data to
            fit against.
        num_columns_to_fit (int, optional): Number of data columns to fit.
        fixed_params_names (list, optional): Names of parameters that are
            held constant.
        fixed_param_values (list, optional): Values of the fixed parameters.
    """

    def __init__(self) -> None:
        """Initialize the problem object with empty / ``None`` placeholders.

        All attributes are placeholders that should be populated by concrete
        implementations or external configuration before the problem is used.
        """
        # Assigned exactly once and kept as a dict (never None) so that the
        # attribute always matches its documented type and can be iterated or
        # indexed without a None check.
        self.params_to_fit_names = {}
        # unsteady problems
        self.y0 = None
        self.t_eval = None
        self.dataset = None
        self.num_columns_to_fit = None
        self.fixed_params_names = None
        self.fixed_param_values = None

    def integrate_system(self, *args):
        """Integrate the system for the given parameters and return the solution.

        Public entry point that forwards every positional argument unchanged
        to the concrete implementation's ``_integrate_system`` hook.

        Args:
            *args: Positional arguments passed through to ``_integrate_system``.

        Returns:
            Whatever ``_integrate_system`` returns (typically solution
            trajectories).

        Note:
            Concrete implementations must provide ``_integrate_system`` with
            the actual integration logic.
        """
        return self._integrate_system(*args)

    # NOTE: JIT compilation of compute_loss is currently disabled; the
    # decorator below is intentionally left commented out.
    #@partial(jax.jit,static_argnums=(0,))
    def compute_loss(self, *args):
        """Compute the loss/objective function value for the given parameters.

        Public entry point that forwards every positional argument unchanged
        to the concrete implementation's ``_compute_loss`` hook. The loss
        quantifies how well the model fits the experimental data.

        Args:
            *args: Positional arguments (typically parameter values) passed
                through to ``_compute_loss``.

        Returns:
            Whatever ``_compute_loss`` returns; by convention a loss value
            where lower values indicate better fits.
        """
        return self._compute_loss(*args)

    def compute_all_losses(self, population_points: np.ndarray):
        """Compute loss values for an entire population of parameter sets.

        Forwards ``population_points`` to the concrete implementation's
        ``_compute_all_losses`` hook. Evaluating many parameter combinations
        at once is useful for population-based optimizers such as PSO.

        Args:
            population_points (numpy.ndarray): Array of parameter sets,
                documented as having shape ``(n_individuals, n_parameters)``.

        Returns:
            Whatever ``_compute_all_losses`` returns (typically an array with
            one loss value per parameter set).

        Note:
            Concrete implementations must provide ``_compute_all_losses``
            with the batch loss computation logic.
        """
        return self._compute_all_losses(population_points)

    def plot_result(self, design_point: np.ndarray, label: str = "default"):
        """Plot the result for a given parameter set.

        Forwards ``design_point`` and ``label`` to the concrete
        implementation's ``_plot_result`` hook, which is responsible for the
        actual visualization.

        Args:
            design_point (numpy.ndarray): Parameter set to visualize.
            label (str, optional): Label for the plot. Defaults to "default".

        Returns:
            Whatever ``_plot_result`` returns (typically a matplotlib figure
            or ``None``).

        Note:
            Concrete implementations must provide ``_plot_result`` with the
            actual plotting logic.
        """
        return self._plot_result(design_point, label)