# metrics

## plots

### `get_confusion_matrix(y_true, y_pred, display_labels, name='')`
Generate and display a confusion matrix plot.
Parameters:

Name | Type | Description | Default
---|---|---|---
`y_true` | `Collection` | True labels. | *required*
`y_pred` | `Collection` | Predicted labels. | *required*
`display_labels` | `list` | Labels to be displayed on the confusion matrix plot. | *required*
`name` | `str` | Name of the plot. Defaults to an empty string. | `''`
Returns:

Name | Type | Description
---|---|---
`ConfusionMatrixDisplay` | `ConfusionMatrixDisplay` | A display object for the confusion matrix plot.

Source code in `aimet_ml/metrics/plots.py`, lines 6-24.
### `get_prc_display(y_true, y_score, name='')`
Generate and display a precision-recall curve plot.
Parameters:

Name | Type | Description | Default
---|---|---|---
`y_true` | `Collection` | True labels. | *required*
`y_score` | `Collection` | Predicted scores or probabilities. | *required*
`name` | `str` | Name of the plot. Defaults to an empty string. | `''`
Returns:

Name | Type | Description
---|---|---
`PrecisionRecallDisplay` | `PrecisionRecallDisplay` | A display object for the precision-recall curve plot.

Source code in `aimet_ml/metrics/plots.py`, lines 27-41.
### `get_roc_display(y_true, y_score, name='')`
Generate and display a ROC curve plot.
Parameters:

Name | Type | Description | Default
---|---|---|---
`y_true` | `Collection` | True labels. | *required*
`y_score` | `Collection` | Predicted scores or probabilities. | *required*
`name` | `str` | Name of the plot. Defaults to an empty string. | `''`
Returns:

Name | Type | Description
---|---|---
`RocCurveDisplay` | `RocCurveDisplay` | A display object for the ROC curve plot.

Source code in `aimet_ml/metrics/plots.py`, lines 44-58.
## reports

### `add_metric_to_report(cls_report, metric_name, label_names, metric_values)`
Adds metric values to a classification report.
Parameters:

Name | Type | Description | Default
---|---|---|---
`cls_report` | `dict` | The classification report as a dictionary. | *required*
`metric_name` | `str` | The name of the metric to add. | *required*
`label_names` | `Collection[str]` | Collection of label names. | *required*
`metric_values` | `Collection[Union[float, int]]` | Collection of metric values corresponding to `label_names`. | *required*
Raises:

Type | Description
---|---
`AssertionError` | If the lengths of `label_names` and `metric_values` do not match.

Source code in `aimet_ml/metrics/reports.py`, lines 25-61.
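The report layout is not shown here; this sketch assumes the shape produced by scikit-learn's `classification_report(output_dict=True)`, where each label maps to a dict of per-metric values. The body (including the documented `AssertionError`) is illustrative:

```python
from typing import Collection, Union

def add_metric_to_report(
    cls_report: dict,
    metric_name: str,
    label_names: Collection[str],
    metric_values: Collection[Union[float, int]],
) -> None:
    """Illustrative sketch: attach one metric value per label to the report."""
    # Documented behavior: mismatched lengths raise AssertionError.
    assert len(label_names) == len(metric_values), \
        "label_names and metric_values must have the same length"
    for label, value in zip(label_names, metric_values):
        # Store the metric under the per-label entry, creating it if absent.
        cls_report.setdefault(label, {})[metric_name] = value

report = {"cat": {"precision": 0.9}, "dog": {"precision": 0.8}}
add_metric_to_report(report, "auc", ["cat", "dog"], [0.95, 0.91])
print(report["cat"]["auc"])  # → 0.95
```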
### `flatten_dict(d, prefix='')`
Recursively flattens a nested dictionary.
Parameters:

Name | Type | Description | Default
---|---|---|---
`d` | `dict` | The input dictionary to flatten. | *required*
`prefix` | `str` | The prefix to be added to flattened keys. Defaults to `""`. | `''`
Returns:

Name | Type | Description
---|---|---
`dict` | `Dict[str, Any]` | A flattened dictionary.

Source code in `aimet_ml/metrics/reports.py`, lines 4-22.
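The behavior described above can be sketched as follows. The `"."` key separator is an assumption for illustration; the actual `aimet_ml` source may join nested keys differently:

```python
from typing import Any, Dict

def flatten_dict(d: dict, prefix: str = "") -> Dict[str, Any]:
    """Illustrative sketch: recursively flatten nested dicts into one level."""
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        new_key = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse, extending the prefix (assumed "." separator).
            flat.update(flatten_dict(value, prefix=f"{new_key}."))
        else:
            flat[new_key] = value
    return flat

print(flatten_dict({"a": {"b": 1, "c": {"d": 2}}, "e": 3}))
# → {'a.b': 1, 'a.c.d': 2, 'e': 3}
```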