Skip to content

metrics

plots

get_confusion_matrix(y_true, y_pred, display_labels, name='')

Generate and display a confusion matrix plot.

Parameters:

Name Type Description Default
y_true Collection

True labels.

required
y_pred Collection

Predicted labels.

required
display_labels list

Labels to be displayed on the confusion matrix plot.

required
name str

Name of the plot. Defaults to an empty string.

''

Returns:

Name Type Description
ConfusionMatrixDisplay ConfusionMatrixDisplay

A display object for the confusion matrix plot.

Source code in aimet_ml/metrics/plots.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def get_confusion_matrix(
    y_true: Collection, y_pred: Collection, display_labels: list, name: str = ""
) -> ConfusionMatrixDisplay:
    """
    Generate and display a confusion matrix plot.

    Args:
        y_true (Collection): True labels.
        y_pred (Collection): Predicted labels.
        display_labels (list): Labels to be displayed on the confusion matrix plot.
        name (str, optional): Name of the plot. Defaults to an empty string.

    Returns:
        ConfusionMatrixDisplay: A display object for the confusion matrix plot.
    """
    # from_predictions() already renders the plot and populates `ax_`;
    # calling .plot() again would create a duplicate figure (and is
    # inconsistent with get_prc_display / get_roc_display below).
    cm_display = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, display_labels=display_labels)
    cm_display.ax_.set_title(name)
    return cm_display

get_prc_display(y_true, y_score, name='')

Generate and display a precision-recall curve plot.

Parameters:

Name Type Description Default
y_true Collection

True labels.

required
y_score Collection

Predicted scores or probabilities.

required
name str

Name of the plot. Defaults to an empty string.

''

Returns:

Name Type Description
PrecisionRecallDisplay PrecisionRecallDisplay

A display object for the precision-recall curve plot.

Source code in aimet_ml/metrics/plots.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def get_prc_display(y_true: Collection, y_score: Collection, name: str = "") -> PrecisionRecallDisplay:
    """
    Build and display a precision-recall curve from predicted scores.

    Args:
        y_true (Collection): True labels.
        y_score (Collection): Predicted scores or probabilities.
        name (str, optional): Name of the plot. Defaults to an empty string.

    Returns:
        PrecisionRecallDisplay: A display object for the precision-recall curve plot.
    """
    # from_predictions() renders the curve and exposes the axes via `ax_`.
    display = PrecisionRecallDisplay.from_predictions(y_true, y_score, name=name)
    axes = display.ax_
    axes.set_title(name)
    return display

get_roc_display(y_true, y_score, name='')

Generate and display a ROC curve plot.

Parameters:

Name Type Description Default
y_true Collection

True labels.

required
y_score Collection

Predicted scores or probabilities.

required
name str

Name of the plot. Defaults to an empty string.

''

Returns:

Name Type Description
RocCurveDisplay RocCurveDisplay

A display object for the ROC curve plot.

Source code in aimet_ml/metrics/plots.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def get_roc_display(y_true: Collection, y_score: Collection, name: str = "") -> RocCurveDisplay:
    """
    Build and display a ROC curve from predicted scores.

    Args:
        y_true (Collection): True labels.
        y_score (Collection): Predicted scores or probabilities.
        name (str, optional): Name of the plot. Defaults to an empty string.

    Returns:
        RocCurveDisplay: A display object for the ROC curve plot.
    """
    # from_predictions() renders the curve and exposes the axes via `ax_`.
    display = RocCurveDisplay.from_predictions(y_true, y_score, name=name)
    axes = display.ax_
    axes.set_title(name)
    return display

reports

add_metric_to_report(cls_report, metric_name, label_names, metric_values)

Adds metric values to a classification report.

Parameters:

Name Type Description Default
cls_report dict

The classification report as a dictionary.

required
metric_name str

The name of the metric to add.

required
label_names Collection[str]

Collection of label names.

required
metric_values Collection[Union[float, int]]

Collection of metric values corresponding to label_names.

required

Raises:

Type Description
ValueError

If the lengths of label_names and metric_values do not match, if label_names contains duplicate entries, or if any label name is not present in cls_report.

Source code in aimet_ml/metrics/reports.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def add_metric_to_report(
    cls_report: Dict[str, Dict[str, Any]],
    metric_name: str,
    label_names: Collection[str],
    metric_values: Collection[Union[float, int]],
) -> None:
    """
    Adds metric values to a classification report in place.

    Per-label values are written under ``cls_report[label][metric_name]``;
    "macro avg" (unweighted mean) and "weighted avg" (support-weighted mean)
    entries are updated as well.

    Args:
        cls_report (dict): The classification report as a dictionary.
        metric_name (str): The name of the metric to add.
        label_names (Collection[str]): Collection of label names.
        metric_values (Collection[Union[float, int]]): Collection of metric values corresponding to label_names.

    Raises:
        ValueError: If the lengths of label_names and metric_values do not match,
            if label_names contains duplicates, or if any label name is missing
            from cls_report.
    """
    if len(label_names) != len(metric_values):
        raise ValueError('label_names and metric_values must have the same length')

    if len(set(label_names)) != len(label_names):
        raise ValueError('Elements in label_names must be distinct from each other')

    if any(label_name not in cls_report for label_name in label_names):
        raise ValueError('All elements in label_names must be in the target_names of cls_report')

    cls_report["macro avg"][metric_name] = 0
    cls_report["weighted avg"][metric_name] = 0

    # Loop-invariant weights: macro weight is uniform; weighted average uses
    # each label's support relative to the total support.
    macro_w = 1 / len(label_names)
    total_support = cls_report["weighted avg"]["support"]

    for label_name, metric_value in zip(label_names, metric_values):
        micro_w = cls_report[label_name]["support"] / total_support

        cls_report[label_name][metric_name] = metric_value
        cls_report["macro avg"][metric_name] += metric_value * macro_w
        cls_report["weighted avg"][metric_name] += metric_value * micro_w

flatten_dict(d, prefix='')

Recursively flattens a nested dictionary.

Parameters:

Name Type Description Default
d dict

The input dictionary to flatten.

required
prefix str

The prefix to be added to flattened keys. Defaults to "".

''

Returns:

Name Type Description
dict Dict[str, Any]

A flattened dictionary.

Source code in aimet_ml/metrics/reports.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
def flatten_dict(d: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """
    Recursively flattens a nested dictionary.

    Nested keys are joined to their parent key with an underscore, e.g.
    ``{"a": {"b": 1}}`` becomes ``{"a_b": 1}``.

    Args:
        d (dict): The input dictionary to flatten.
        prefix (str, optional): The prefix to be added to flattened keys. Defaults to "".

    Returns:
        dict: A flattened dictionary.
    """
    result: Dict[str, Any] = {}
    for key, value in d.items():
        full_key = key if not prefix else f"{prefix}_{key}"
        if not isinstance(value, dict):
            result[full_key] = value
        else:
            # Recurse into the nested dict, carrying the joined key as prefix.
            result.update(flatten_dict(value, prefix=full_key))
    return result