Skip to content

model_selection

splits

get_splitter(stratify_cols=None, group_cols=None, n_splits=5, random_state=1414)

Get a cross-validation splitter based on input parameters.

Parameters:

Name Type Description Default
stratify_cols Collection[str]

Column names for stratification. Defaults to None.

None
group_cols Collection[str]

Column names for grouping. Defaults to None.

None
n_splits int

Number of splits in the cross-validation. Defaults to 5.

5
random_state int

Seed for random number generator. Defaults to 1414.

1414

Returns:

Name Type Description
BaseCrossValidator BaseCrossValidator

A cross-validation splitter based on the input parameters.

Source code in aimet_ml/model_selection/splits.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def get_splitter(
    stratify_cols: Optional[Collection[str]] = None,
    group_cols: Optional[Collection[str]] = None,
    n_splits: int = 5,
    random_state: int = 1414,
) -> BaseCrossValidator:
    """
    Build the scikit-learn cross-validator that matches the requested constraints.

    Empty collections are treated exactly like ``None``. The splitter is chosen as:

    * stratify and group columns -> ``StratifiedGroupKFold`` (shuffled, seeded)
    * stratify columns only      -> ``StratifiedKFold`` (shuffled, seeded)
    * group columns only         -> ``GroupKFold`` (constructed without ``shuffle`` or
      ``random_state``, so ``random_state`` is not used in this case)
    * neither                    -> ``KFold`` (shuffled, seeded)

    Args:
        stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None.
        group_cols (Collection[str], optional): Column names for grouping. Defaults to None.
        n_splits (int, optional): Number of splits in the cross-validation. Defaults to 5.
        random_state (int): Seed for random number generator. Defaults to 1414.

    Returns:
        BaseCrossValidator: A cross-validation splitter based on the input parameters.

    Raises:
        ValueError: If ``n_splits`` is not greater than 1, or if ``stratify_cols`` and
            ``group_cols`` share a column name.
    """
    if n_splits <= 1:
        raise ValueError("n_splits must be greater than 1")

    use_stratify = bool(stratify_cols)
    use_groups = bool(group_cols)

    # A column cannot drive both the stratification labels and the group ids.
    stratify_set = set(stratify_cols) if use_stratify else set()
    group_set = set(group_cols) if use_groups else set()
    if stratify_set & group_set:
        raise ValueError("group_cols and stratify_cols must be disjoint")

    if use_stratify:
        splitter_cls = StratifiedGroupKFold if use_groups else StratifiedKFold
        return splitter_cls(n_splits=n_splits, shuffle=True, random_state=random_state)

    if use_groups:
        return GroupKFold(n_splits=n_splits)

    return KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

join_cols(df, cols, sep='_')

Concatenate the specified columns of a DataFrame with a separator.

Parameters:

Name Type Description Default
df DataFrame

The DataFrame to operate on.

required
cols Collection[str]

Column names to concatenate.

required
sep str

The separator to use between the column values. Defaults to "_".

'_'

Returns:

Type Description
Series

pd.Series: A Series containing the concatenated values.

Source code in aimet_ml/model_selection/splits.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
def join_cols(df: pd.DataFrame, cols: Collection[str], sep: str = "_") -> pd.Series:
    """
    Concatenate the specified columns of a DataFrame with a separator.

    Each selected column is converted to ``str`` on its own (column-wise), so every value
    keeps its own column's formatting: an int column renders ``1``, not ``1.0``, even when
    another selected column is a float. A row-wise conversion would first upcast an
    all-numeric row to a common dtype. That changes the text and can merge distinct large
    integer ids that round to the same float.

    Args:
        df (pd.DataFrame): The DataFrame to operate on.
        cols (Collection[str]): Column names to concatenate, in order. Any collection works
            (list, tuple, set, ...); a set gives no guaranteed column order.
        sep (str, optional): The separator to use between the column values. Defaults to "_".

    Returns:
        pd.Series: An unnamed Series aligned with ``df.index`` containing the concatenated values.

    Raises:
        ValueError: If ``cols`` is empty.
    """
    # Materialise as a list: pandas treats a tuple passed to ``df[...]`` as a single column
    # key, and recent pandas versions reject sets as indexers. Both are valid Collection[str].
    col_list = list(cols)
    if not col_list:
        raise ValueError("At least one column name is required, got empty")

    joined = df[col_list[0]].astype(str)
    for col in col_list[1:]:
        joined = joined + sep + df[col].astype(str)

    # ``joined`` is a fresh Series (astype copies), so renaming it does not touch ``df``.
    joined.name = None
    return joined

split_dataset(dataset_df, val_fraction=0.1, test_n_splits=5, stratify_cols=None, group_cols=None, train_split_name_format='train_fold_{}', val_split_name_format='val_fold_{}', test_split_name_format='test_fold_{}', random_seed=1414)

Split a dataset into k-fold cross-validation sets with stratification and grouping.

The dataset will be split into k-fold cross-validation sets, each containing development and test sets. For each fold, the development set will be further split into training and validation sets. The final data splits include k test sets, k training sets, and k validation sets.

Parameters:

Name Type Description Default
dataset_df DataFrame

The input DataFrame to be split.

required
val_fraction Union[float, int]

The fraction of each fold's development set to be used for validation. If a float is given, it's rounded to the nearest fraction. If an integer (n) is given, the fraction is calculated as 1/n. Defaults to 0.1.

0.1
test_n_splits int

Number of cross-validation splits. Defaults to 5.

5
stratify_cols Collection[str]

Column names for stratification. Defaults to None.

None
group_cols Collection[str]

Column names for grouping. Defaults to None.

None
train_split_name_format str

Format for naming training splits. Defaults to "train_fold_{}".

'train_fold_{}'
val_split_name_format str

Format for naming validation splits. Defaults to "val_fold_{}".

'val_fold_{}'
test_split_name_format str

Format for naming test splits. Defaults to "test_fold_{}".

'test_fold_{}'
random_seed int

Random seed for reproducibility. Defaults to 1414.

1414

Returns:

Type Description
Dict[str, DataFrame]

Dict[str, pd.DataFrame]: A dictionary containing the split DataFrames.

Source code in aimet_ml/model_selection/splits.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def split_dataset(
    dataset_df: pd.DataFrame,
    val_fraction: Union[float, int] = 0.1,
    test_n_splits: int = 5,
    stratify_cols: Optional[Collection[str]] = None,
    group_cols: Optional[Collection[str]] = None,
    train_split_name_format: str = "train_fold_{}",
    val_split_name_format: str = "val_fold_{}",
    test_split_name_format: str = "test_fold_{}",
    random_seed: int = 1414,
) -> Dict[str, pd.DataFrame]:
    """
    Split a dataset into k-fold cross-validation sets with stratification and grouping.

    The dataset will be split into k-fold cross-validation sets, each containing development and test sets.
    For each fold, the development set will be further split into training and validation sets.
    The final data splits include k test sets, k training sets, and k validation sets.

    Args:
        dataset_df (pd.DataFrame): The input DataFrame to be split.
        val_fraction (Union[float, int], optional): The fraction of each fold's development set
                                                     to be used for validation.
                                                     If a float is given, it's rounded to the nearest fraction.
                                                     If an integer (n) is given, the fraction is calculated as 1/n.
                                                     Defaults to 0.1.
        test_n_splits (int, optional): Number of cross-validation splits. Defaults to 5.
        stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None.
        group_cols (Collection[str], optional): Column names for grouping. Defaults to None.
        train_split_name_format (str, optional): Format for naming training splits. Defaults to "train_fold_{}".
        val_split_name_format (str, optional): Format for naming validation splits. Defaults to "val_fold_{}".
        test_split_name_format (str, optional): Format for naming test splits. Defaults to "test_fold_{}".
        random_seed (int, optional): Random seed for reproducibility. Defaults to 1414.

    Returns:
        Dict[str, pd.DataFrame]: A dictionary containing the split DataFrames. Fold numbers
        passed to the name formats start at 1. For each fold, the keys are inserted in the
        order test, train, val. Every DataFrame has a fresh 0-based index.

    Raises:
        ValueError: If ``test_n_splits`` is not greater than 1.
    """
    if test_n_splits <= 1:
        raise ValueError("test_n_splits must be greater than 1")

    data_splits: Dict[str, pd.DataFrame] = {}

    # Outer cross-validation: each fold contributes one held-out test set.
    k_fold_splitter = get_splitter(stratify_cols, group_cols, test_n_splits, random_seed)

    stratify = join_cols(dataset_df, stratify_cols) if stratify_cols else None
    groups = join_cols(dataset_df, group_cols) if group_cols else None

    folds = k_fold_splitter.split(X=dataset_df, y=stratify, groups=groups)
    for k, (dev_rows, test_rows) in enumerate(folds, start=1):
        data_splits[test_split_name_format.format(k)] = dataset_df.iloc[test_rows].reset_index(drop=True)

        # Inner split: the validation set is carved out of this fold's development rows only,
        # so the fold's test rows never appear in its training or validation sets.
        dev_dataset_df = dataset_df.iloc[dev_rows].reset_index(drop=True)
        train_dataset_df, val_dataset_df = stratified_group_split(
            dev_dataset_df, val_fraction, stratify_cols, group_cols, random_seed
        )
        data_splits[train_split_name_format.format(k)] = train_dataset_df
        data_splits[val_split_name_format.format(k)] = val_dataset_df

    return data_splits

split_dataset_single_test(dataset_df, test_fraction=0.2, val_n_splits=5, stratify_cols=None, group_cols=None, test_split_name='test', dev_split_name='dev', train_split_name_format='train_fold_{}', val_split_name_format='val_fold_{}', random_seed=1414)

Split a dataset into development, test, and cross-validation sets with stratification and grouping.

The dataset will be split into a development set and a test set. The development set will then be further split into k-fold cross-validation sets, each containing its own training and validation sets. The final data splits include a test set, k training sets, and k validation sets.

Parameters:

Name Type Description Default
dataset_df DataFrame

The input DataFrame to be split.

required
test_fraction Union[float, int]

The fraction of data to be used for testing. If a float is given, it's rounded to the nearest fraction. If an integer (n) is given, the fraction is calculated as 1/n. Defaults to 0.2.

0.2
val_n_splits int

Number of cross-validation splits. Defaults to 5.

5
stratify_cols Collection[str]

Column names for stratification. Defaults to None.

None
group_cols Collection[str]

Column names for grouping. Defaults to None.

None
test_split_name str

Name for the test split. Defaults to "test".

'test'
dev_split_name str

Name for the development split. Defaults to "dev".

'dev'
train_split_name_format str

Format for naming training splits. Defaults to "train_fold_{}".

'train_fold_{}'
val_split_name_format str

Format for naming validation splits. Defaults to "val_fold_{}".

'val_fold_{}'
random_seed int

Random seed for reproducibility. Defaults to 1414.

1414

Returns:

Type Description
Dict[str, DataFrame]

Dict[str, pd.DataFrame]: A dictionary containing the split DataFrames.

Source code in aimet_ml/model_selection/splits.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def split_dataset_single_test(
    dataset_df: pd.DataFrame,
    test_fraction: Union[float, int] = 0.2,
    val_n_splits: int = 5,
    stratify_cols: Optional[Collection[str]] = None,
    group_cols: Optional[Collection[str]] = None,
    test_split_name: str = "test",
    dev_split_name: str = "dev",
    train_split_name_format: str = "train_fold_{}",
    val_split_name_format: str = "val_fold_{}",
    random_seed: int = 1414,
) -> Dict[str, pd.DataFrame]:
    """
    Hold out one test set, then cross-validate over the remaining development data.

    First, ``stratified_group_split`` divides the dataset into a development set and a
    single test set. Next, the development set is split into ``val_n_splits`` folds, each
    yielding a training set and a validation set. The result therefore holds the
    development set, the test set, k training sets, and k validation sets.

    Args:
        dataset_df (pd.DataFrame): The input DataFrame to be split.
        test_fraction (Union[float, int], optional): The fraction of data to be used for testing.
                                                     If a float is given, it's rounded to the nearest fraction.
                                                     If an integer (n) is given, the fraction is calculated as 1/n.
                                                     Defaults to 0.2.
        val_n_splits (int, optional): Number of cross-validation splits. Defaults to 5.
        stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None.
        group_cols (Collection[str], optional): Column names for grouping. Defaults to None.
        test_split_name (str, optional): Name for the test split. Defaults to "test".
        dev_split_name (str, optional): Name for the development split. Defaults to "dev".
        train_split_name_format (str, optional): Format for naming training splits. Defaults to "train_fold_{}".
        val_split_name_format (str, optional): Format for naming validation splits. Defaults to "val_fold_{}".
        random_seed (int, optional): Random seed for reproducibility. Defaults to 1414.

    Returns:
        Dict[str, pd.DataFrame]: The development and test sets first, followed by one
        training and one validation DataFrame per fold. Fold numbers start at 1.

    Raises:
        ValueError: If ``val_n_splits`` is not greater than 1.
    """
    if val_n_splits <= 1:
        raise ValueError("val_n_splits must be greater than 1")

    # Hold out the single test set; the remainder is the development set.
    dev_df, test_df = stratified_group_split(
        dataset_df, test_fraction, stratify_cols, group_cols, random_seed
    )
    splits: Dict[str, pd.DataFrame] = {dev_split_name: dev_df, test_split_name: test_df}

    # Cross-validate over the development rows only.
    fold_splitter = get_splitter(stratify_cols, group_cols, val_n_splits, random_seed)
    labels = join_cols(dev_df, stratify_cols) if stratify_cols else None
    group_ids = join_cols(dev_df, group_cols) if group_cols else None

    folds = fold_splitter.split(X=dev_df, y=labels, groups=group_ids)
    for fold_no, (train_idx, val_idx) in enumerate(folds, start=1):
        splits[train_split_name_format.format(fold_no)] = dev_df.iloc[train_idx].reset_index(drop=True)
        splits[val_split_name_format.format(fold_no)] = dev_df.iloc[val_idx].reset_index(drop=True)

    return splits

stratified_group_split(dataset_df, test_fraction=0.2, stratify_cols=None, group_cols=None, random_seed=1414)

Split a dataset into development and test sets with stratification and grouping.

Parameters:

Name Type Description Default
dataset_df DataFrame

The input DataFrame to be split.

required
test_fraction Union[float, int]

The fraction of data to be used for testing. If a float (0, 1) is given, it's rounded to the nearest fraction. If an integer (n > 1) is given, the fraction is calculated as 1/n. Defaults to 0.2.

0.2
stratify_cols Collection[str]

Column names for stratification. Defaults to None.

None
group_cols Collection[str]

Column names for grouping. Defaults to None.

None
random_seed int

Random seed for reproducibility. Defaults to 1414.

1414

Returns:

Type Description
Tuple[DataFrame, DataFrame]

Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the development and test DataFrames.

Source code in aimet_ml/model_selection/splits.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def stratified_group_split(
    dataset_df: pd.DataFrame,
    test_fraction: Union[float, int] = 0.2,
    stratify_cols: Optional[Collection[str]] = None,
    group_cols: Optional[Collection[str]] = None,
    random_seed: int = 1414,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Split a dataset into development and test sets with stratification and grouping.

    The split is produced by running a k-fold splitter (see ``get_splitter``), where
    k = round(1 / min(f, 1 - f)) and f is the requested test fraction. The algorithm
    keeps the fold whose held-out size is closest to min(f, 1 - f); on ties, the earliest
    such fold wins. When f > 0.5, the roles of that fold and its complement are swapped,
    so the test set is the larger side.

    Args:
        dataset_df (pd.DataFrame): The input DataFrame to be split.
        test_fraction (Union[float, int], optional): The fraction of data to be used for testing.
                                                     If a non-integer number in (0, 1) is given, it's rounded
                                                     to the nearest fraction.
                                                     If an integer (n > 1) is given, the fraction is calculated
                                                     as 1/n. NumPy scalar types are accepted as well.
                                                     Defaults to 0.2.
        stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None.
        group_cols (Collection[str], optional): Column names for grouping. Defaults to None.
        random_seed (int, optional): Random seed for reproducibility. Defaults to 1414.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the development and test DataFrames,
        each with a fresh 0-based index.

    Raises:
        ValueError: If ``test_fraction`` is not positive, equals 1, or is a non-integer greater than 1.
    """
    import numbers

    if test_fraction <= 0:
        raise ValueError("test_fraction must be greater than 0")

    if test_fraction == 1:
        raise ValueError("test_fraction must not equal to 1")

    # Dispatch on the numeric ABCs rather than the builtin ``int``/``float`` so that NumPy
    # scalars behave the same way. ``np.int64(5)`` is an Integral but not an ``int``, and
    # ``np.float32(1.5)`` is a Real but not a ``float``.
    if isinstance(test_fraction, numbers.Integral):
        # An integer n means "one n-th of the data".
        test_fraction = 1 / test_fraction
    elif test_fraction > 1:
        raise ValueError("test_fraction provided as float must be less than 1")
    test_fraction = float(test_fraction)

    # Aim the k-fold at the smaller side of the split; it is swapped back below if needed.
    split_fraction = min(test_fraction, 1 - test_fraction)
    n_splits = round(1 / split_fraction)
    splitter = get_splitter(stratify_cols, group_cols, n_splits, random_seed)

    stratify = join_cols(dataset_df, stratify_cols) if stratify_cols else None
    groups = join_cols(dataset_df, group_cols) if group_cols else None

    lowest_diff = float("inf")
    best_dev_rows, best_test_rows = None, None
    for dev_rows, test_rows in splitter.split(X=dataset_df, y=stratify, groups=groups):
        fraction = len(test_rows) / len(dataset_df)
        diff = abs(split_fraction - fraction)
        # Strict "<" keeps the earliest fold on ties.
        if diff < lowest_diff:
            lowest_diff = diff
            best_dev_rows = dev_rows
            best_test_rows = test_rows

    # The chosen fold approximates the smaller side; for f > 0.5 the test set is the larger side.
    if test_fraction > 0.5:
        best_dev_rows, best_test_rows = best_test_rows, best_dev_rows

    dev_dataset_df = dataset_df.iloc[best_dev_rows].reset_index(drop=True)
    test_dataset_df = dataset_df.iloc[best_test_rows].reset_index(drop=True)

    return dev_dataset_df, test_dataset_df