Skip to content

ParameterValidator

Validate optimized parameters with quality metrics and visualization.

Usage

from parametrizani import ParameterValidator

# Output files (report and plot) are written under the given working directory.
validator = ParameterValidator('./work')

# ref_angles, ref_energies and fitted_energies are sequences supplied by the
# caller. Besides returning the metrics, this call writes
# 'validation_report.txt' and, when matplotlib can be imported,
# 'validation_plot.png' into the working directory.
validation = validator.validate_parameters(
    ref_angles, ref_energies, fitted_energies
)

# The result is a dict; 'quality' is the rating derived from the RMSE.
print(f"Quality: {validation['quality']}")
print(f"RMSE: {validation['rmse']:.4f} kcal/mol")
print(f"MAE: {validation['mae']:.4f} kcal/mol")
print(f"R²: {validation['r_squared']:.4f}")

Quality Ratings

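| Rating | RMSE (kcal/mol) | Interpretation |
|--------|-----------------|----------------|
| Excellent | ≤ 0.25 | Near-QM accuracy |
| Good | ≤ 0.50 | Suitable for most applications |
| Acceptable | ≤ 1.00 | Usable with caution |
| Poor | > 1.00 | Consider different parameters |

The rating is the first row whose bound the RMSE satisfies, so "Good" means 0.25 < RMSE ≤ 0.50.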

API Reference

parametrizani.validation.ParameterValidator

Validate optimized dihedral parameters.

Computes RMSE, MAE, R², correlation and generates reports/plots.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `work_dir` | `str` | Working directory for output files. | `'./work'` |
Source code in parametrizani/validation.py
class ParameterValidator:
    """
    Validate optimized dihedral parameters.

    Computes RMSE, MAE, R\u00b2, correlation and generates reports/plots.

    Parameters
    ----------
    work_dir : str
        Working directory for output files.
    """

    # Upper RMSE bound (kcal/mol) for each rating; the smallest bound that the
    # RMSE satisfies determines the rating.
    QUALITY_THRESHOLDS = {
        'Excellent': 0.25,
        'Good': 0.50,
        'Acceptable': 1.00,
        'Poor': float('inf'),
    }

    # Style names tried in order: matplotlib >= 3.6 ships the seaborn styles
    # under the ``seaborn-v0_8-`` prefix, older releases only know the
    # unprefixed name.
    _PLOT_STYLES = ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid')

    def __init__(self, work_dir: str = './work'):
        # create_work_dir is defined elsewhere in the package; its return
        # value is used as the directory for every report/plot file below.
        self.work_dir = create_work_dir(work_dir)

    @staticmethod
    def _as_profile(values, name):
        """Return *values* as a 1-D float array; raise ``ValueError`` otherwise."""
        profile = np.asarray(values, dtype=float)
        if profile.ndim != 1:
            raise ValueError(f"{name} must be a one-dimensional sequence of energies")
        return profile

    @staticmethod
    def _error_metrics(ref, pred):
        """Return ``(rmse, mae, r_squared)`` of *pred* measured against *ref*."""
        residuals = ref - pred
        ss_res = np.sum(residuals ** 2)
        ss_tot = np.sum((ref - np.mean(ref)) ** 2)
        rmse = np.sqrt(np.mean(residuals ** 2))
        mae = np.mean(np.abs(residuals))
        # A flat reference has no variance to explain; report 0.0 instead of
        # dividing by zero.
        r_squared = 1.0 - (ss_res / ss_tot) if ss_tot > 0 else 0.0
        return rmse, mae, r_squared

    def _rate_quality(self, rmse):
        """Return the first rating (by ascending threshold) that *rmse* satisfies."""
        # Sorting makes the result independent of the dict's insertion order,
        # e.g. when a subclass overrides QUALITY_THRESHOLDS.
        for rating, threshold in sorted(self.QUALITY_THRESHOLDS.items(), key=lambda item: item[1]):
            if rmse <= threshold:
                return rating
        # Reached only when no comparison succeeds (e.g. a NaN RMSE).
        return 'Poor'

    def _apply_plot_style(self, plt):
        """Activate the first available style from ``_PLOT_STYLES``, if any."""
        for style in self._PLOT_STYLES:
            try:
                plt.style.use(style)
                return
            except OSError:
                # Unknown style name on this matplotlib version; try the next.
                continue

    def validate_parameters(
        self, angles: List[float], ref_energies: List[float],
        fitted_energies: List[float], mm_energies: Optional[List[float]] = None,
        labels: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """
        Validate fitted parameters against reference.

        Writes ``validation_report.txt`` (UTF-8) into the working directory
        and, when matplotlib is importable, ``validation_plot.png``.

        Parameters
        ----------
        angles : List[float]
            Scan angles, one per energy point.
        ref_energies : List[float]
            Reference energy profile.
        fitted_energies : List[float]
            Energy profile obtained with the fitted parameters.
        mm_energies : Optional[List[float]]
            Optional additional profile drawn on the plot.
        labels : Optional[Dict[str, str]]
            Optional legend labels; the keys ``'reference'``, ``'fitted'``
            and ``'mm'`` are looked up.

        Returns
        -------
        Dict[str, Any]
            'rmse', 'mae', 'r_squared', 'correlation', 'max_error', 'quality',
            'report_file', 'plot_file'.

        Raises
        ------
        ValueError
            If the energy profiles are empty, are not one-dimensional, or if
            ``angles``, ``ref_energies`` and ``fitted_energies`` do not all
            have the same length.
        """
        ref = self._as_profile(ref_energies, 'ref_energies')
        fitted = self._as_profile(fitted_energies, 'fitted_energies')

        # Fail before any file is written: numpy would otherwise broadcast a
        # length-1 profile silently and zip() would truncate the report.
        if ref.size == 0:
            raise ValueError("ref_energies must contain at least one point")
        if fitted.size != ref.size:
            raise ValueError(
                f"fitted_energies has {fitted.size} points but ref_energies has {ref.size}"
            )
        if len(angles) != ref.size:
            raise ValueError(
                f"angles has {len(angles)} points but ref_energies has {ref.size}"
            )

        rmse, mae, r_squared = self._error_metrics(ref, fitted)
        max_error = np.max(np.abs(ref - fitted))
        # np.corrcoef needs at least two points; it yields NaN when either
        # profile is constant.
        correlation = np.corrcoef(ref, fitted)[0, 1] if len(ref) > 1 else 0.0

        quality = self._rate_quality(rmse)

        logger.info(
            "Validation: RMSE=%.4f, R\u00b2=%.4f, Quality=%s", rmse, r_squared, quality
        )

        report = self._generate_report(
            angles, ref, fitted, mm_energies, rmse, mae, r_squared, correlation, max_error, quality
        )
        report_file = os.path.join(self.work_dir, 'validation_report.txt')
        # Explicit UTF-8: the report contains the non-ASCII character in "R²",
        # which the platform default encoding may be unable to encode.
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write(report)

        plot_file = self._generate_plot(angles, ref, fitted, mm_energies, labels, quality, rmse)

        return {
            'rmse': rmse, 'mae': mae, 'r_squared': r_squared,
            'correlation': correlation, 'max_error': max_error,
            'quality': quality, 'report_file': report_file, 'plot_file': plot_file,
        }

    def compare_methods(self, angles, energy_profiles, reference_key=None):
        """
        Compare multiple energy profiles.

        Parameters
        ----------
        angles : sequence
            Scan angles shared by every profile.
        energy_profiles : dict
            Maps a profile name to its energies.
        reference_key : optional
            Key of the profile the others are measured against; defaults to
            the first key of ``energy_profiles``.

        Returns
        -------
        dict
            ``'reference'`` (the key used), ``'metrics'`` (per-profile dict
            with 'rmse', 'mae', 'r_squared') and ``'plot_file'``.

        Raises
        ------
        ValueError
            If a profile is not one-dimensional or its length differs from
            the reference profile.
        """
        if reference_key is None:
            reference_key = list(energy_profiles.keys())[0]
        ref = self._as_profile(energy_profiles[reference_key], f"profile {reference_key!r}")

        results = {}
        for name, energies in energy_profiles.items():
            if name == reference_key:
                continue
            pred = self._as_profile(energies, f"profile {name!r}")
            # Reject shapes numpy would otherwise broadcast silently.
            if pred.shape != ref.shape:
                raise ValueError(
                    f"profile {name!r} has {pred.size} points but reference "
                    f"{reference_key!r} has {ref.size}"
                )
            rmse, mae, r_squared = self._error_metrics(ref, pred)
            results[name] = {'rmse': rmse, 'mae': mae, 'r_squared': r_squared}

        plot_file = self._generate_comparison_plot(angles, energy_profiles, reference_key)
        return {'reference': reference_key, 'metrics': results, 'plot_file': plot_file}

    def _generate_report(self, angles, ref, fitted, mm_energies, rmse, mae, r_squared, correlation, max_error, quality):
        """
        Build the plain-text validation report and return it as one string.

        ``mm_energies`` is accepted for signature symmetry with the plot
        helper but is not included in the report.
        """
        lines = [
            "=" * 60, "ParametrizANI - Validation Report", "=" * 60, "",
            f"Quality Rating: {quality}", "", "Metrics:",
            f"  RMSE:           {rmse:.4f} kcal/mol",
            f"  MAE:            {mae:.4f} kcal/mol",
            f"  Max Error:      {max_error:.4f} kcal/mol",
            f"  R\u00b2:             {r_squared:.4f}",
            f"  Correlation:    {correlation:.4f}", "",
            "Per-point comparison:",
            f"  {'Angle':>8s} {'Reference':>10s} {'Fitted':>10s} {'Error':>10s}",
        ]
        lines.extend(
            f"  {angle:8.1f} {r:10.4f} {f:10.4f} {abs(r-f):10.4f}"
            for angle, r, f in zip(angles, ref, fitted)
        )
        lines.extend(["", "=" * 60])
        return "\n".join(lines)

    def _generate_plot(self, angles, ref, fitted, mm_energies, labels, quality, rmse):
        """
        Save ``validation_plot.png`` in the working directory.

        Returns the path of the saved file, or ``""`` when matplotlib cannot
        be imported.
        """
        try:
            import matplotlib
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt

            self._apply_plot_style(plt)
            labels = labels or {}
            fig, ax = plt.subplots(1, 1, figsize=(10, 6))
            try:
                ax.plot(angles, ref, 'o-', linewidth=1.5, markersize=4,
                        label=labels.get('reference', 'Reference'), color='#2196F3')
                ax.plot(angles, fitted, 's-', linewidth=1.5, markersize=4,
                        label=labels.get('fitted', 'Optimized'), color='#F44336')
                if mm_energies is not None:
                    ax.plot(angles, mm_energies, '^--', linewidth=1.0, markersize=3,
                            label=labels.get('mm', 'MM (original)'), color='#4CAF50', alpha=0.7)
                ax.set_xlabel('Dihedral Angle (degrees)', fontsize=12)
                ax.set_ylabel('Relative Energy (kcal/mol)', fontsize=12)
                ax.set_title(f'Dihedral Parametrization - {quality} (RMSE: {rmse:.4f} kcal/mol)', fontsize=13)
                ax.legend(fontsize=11, frameon=True)
                ax.set_xticks(np.arange(min(angles), max(angles) + 1, 60))
                plot_file = os.path.join(self.work_dir, 'validation_plot.png')
                fig.savefig(plot_file, dpi=300, bbox_inches='tight')
            finally:
                # Close this specific figure even if plotting fails, so it
                # does not stay registered with pyplot.
                plt.close(fig)
            return plot_file
        except ImportError:
            return ""

    def _generate_comparison_plot(self, angles, energy_profiles, reference_key):
        """
        Save ``comparison_plot.png`` with every profile overlaid.

        The reference profile is drawn with a thicker line. Returns the path
        of the saved file, or ``""`` when matplotlib cannot be imported.
        """
        try:
            import matplotlib
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt

            self._apply_plot_style(plt)
            fig, ax = plt.subplots(1, 1, figsize=(10, 6))
            try:
                colors = ['#2196F3', '#F44336', '#4CAF50', '#FF9800', '#9C27B0', '#00BCD4']
                for i, (name, energies) in enumerate(energy_profiles.items()):
                    lw = 2.0 if name == reference_key else 1.5
                    # The palette is reused cyclically when there are more
                    # profiles than colors.
                    ax.plot(angles, energies, 'o-', linewidth=lw, markersize=4,
                            label=name, color=colors[i % len(colors)])
                ax.set_xlabel('Dihedral Angle (degrees)', fontsize=12)
                ax.set_ylabel('Relative Energy (kcal/mol)', fontsize=12)
                ax.set_title('Energy Profile Comparison', fontsize=13)
                ax.legend(fontsize=11, frameon=True)
                plot_file = os.path.join(self.work_dir, 'comparison_plot.png')
                fig.savefig(plot_file, dpi=300, bbox_inches='tight')
            finally:
                plt.close(fig)
            return plot_file
        except ImportError:
            return ""

__init__

__init__(work_dir: str = './work')
Source code in parametrizani/validation.py
def __init__(self, work_dir: str = './work'):
    """Keep the directory returned by ``create_work_dir`` as ``self.work_dir``."""
    # NOTE(review): create_work_dir is defined elsewhere; presumably it makes
    # sure the directory exists before returning its path -- confirm.
    output_dir = create_work_dir(work_dir)
    self.work_dir = output_dir

validate_parameters

validate_parameters(angles: List[float], ref_energies: List[float], fitted_energies: List[float], mm_energies: Optional[List[float]] = None, labels: Optional[Dict[str, str]] = None) -> Dict[str, Any]

Validate fitted parameters against reference.

Returns:

| Type | Description |
|------|-------------|
| `Dict[str, Any]` | Dictionary with the keys `'rmse'`, `'mae'`, `'r_squared'`, `'correlation'`, `'max_error'`, `'quality'`, `'report_file'` and `'plot_file'`. |

Source code in parametrizani/validation.py
def validate_parameters(
    self, angles: List[float], ref_energies: List[float],
    fitted_energies: List[float], mm_energies: Optional[List[float]] = None,
    labels: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """
    Validate fitted parameters against reference.

    Writes ``validation_report.txt`` (UTF-8) into ``self.work_dir`` and
    delegates plotting to ``self._generate_plot``.

    Parameters
    ----------
    angles : List[float]
        Scan angles, one per energy point.
    ref_energies : List[float]
        Reference energy profile.
    fitted_energies : List[float]
        Energy profile obtained with the fitted parameters.
    mm_energies : Optional[List[float]]
        Optional additional profile forwarded to the report/plot helpers.
    labels : Optional[Dict[str, str]]
        Optional legend labels forwarded to the plot helper.

    Returns
    -------
    Dict[str, Any]
        'rmse', 'mae', 'r_squared', 'correlation', 'max_error', 'quality',
        'report_file', 'plot_file'.

    Raises
    ------
    ValueError
        If the energy profiles are empty, are not one-dimensional, or if
        ``angles``, ``ref_energies`` and ``fitted_energies`` do not all have
        the same length.
    """
    ref = np.asarray(ref_energies, dtype=float)
    fitted = np.asarray(fitted_energies, dtype=float)

    # Fail before any file is written: numpy would otherwise broadcast a
    # length-1 profile silently and the report's zip() would truncate.
    if ref.ndim != 1 or fitted.ndim != 1:
        raise ValueError("ref_energies and fitted_energies must be one-dimensional sequences")
    if ref.size == 0:
        raise ValueError("ref_energies must contain at least one point")
    if fitted.size != ref.size:
        raise ValueError(
            f"fitted_energies has {fitted.size} points but ref_energies has {ref.size}"
        )
    if len(angles) != ref.size:
        raise ValueError(
            f"angles has {len(angles)} points but ref_energies has {ref.size}"
        )

    residuals = ref - fitted
    abs_residuals = np.abs(residuals)

    rmse = np.sqrt(np.mean(residuals ** 2))
    mae = np.mean(abs_residuals)
    max_error = np.max(abs_residuals)

    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((ref - np.mean(ref)) ** 2)
    # A flat reference has no variance to explain; report 0.0 instead of
    # dividing by zero.
    r_squared = 1.0 - (ss_res / ss_tot) if ss_tot > 0 else 0.0
    # np.corrcoef needs at least two points; it yields NaN when either
    # profile is constant.
    correlation = np.corrcoef(ref, fitted)[0, 1] if len(ref) > 1 else 0.0

    # Pick the rating with the smallest threshold the RMSE satisfies; sorting
    # removes the dependence on the dict's insertion order.
    quality = 'Poor'
    for rating, threshold in sorted(self.QUALITY_THRESHOLDS.items(), key=lambda item: item[1]):
        if rmse <= threshold:
            quality = rating
            break

    logger.info(
        "Validation: RMSE=%.4f, R\u00b2=%.4f, Quality=%s", rmse, r_squared, quality
    )

    report = self._generate_report(
        angles, ref, fitted, mm_energies, rmse, mae, r_squared, correlation, max_error, quality
    )
    report_file = os.path.join(self.work_dir, 'validation_report.txt')
    # Explicit UTF-8: the report contains the non-ASCII character in "R²",
    # which the platform default encoding may be unable to encode.
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(report)

    plot_file = self._generate_plot(angles, ref, fitted, mm_energies, labels, quality, rmse)

    return {
        'rmse': rmse, 'mae': mae, 'r_squared': r_squared,
        'correlation': correlation, 'max_error': max_error,
        'quality': quality, 'report_file': report_file, 'plot_file': plot_file,
    }