features
textual
transformers
TransformerFeatureExtractor
Extracts features from input texts using transformer embeddings.
Source code in aimet_ml/features/textual/transformers.py
__init__(model_name, num_emb_layers=4, max_length=512, device='cuda:0')
Initializes the TransformerFeatureExtractor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name | str | The name or path of the pre-trained transformer model. | required |
num_emb_layers | int | Number of layers to use for feature extraction. Default is 4. | 4 |
max_length | int | Maximum length of input text for tokenization. Default is 512. | 512 |
device | str or device | Device to use for computation ('cuda:0', 'cpu', etc.). Default is 'cuda:0' if available, else 'cpu'. | 'cuda:0' |
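The device description mentions a fallback to 'cpu' when 'cuda:0' is not available. The following is a minimal sketch of that kind of fallback, not the library's actual code: the function name `resolve_device` and the `cuda_available` flag are hypothetical. With PyTorch installed, the flag would typically come from `torch.cuda.is_available()`.

```python
def resolve_device(requested: str = "cuda:0", cuda_available: bool = False) -> str:
    """Return the requested device, or 'cpu' if a CUDA device was requested but CUDA is absent.

    Hypothetical helper for illustration; `cuda_available` stands in for a
    runtime check such as torch.cuda.is_available().
    """
    if requested.startswith("cuda") and not cuda_available:
        # A CUDA device was requested but none is usable: fall back to the CPU.
        return "cpu"
    # Either CUDA is usable or a non-CUDA device (e.g. 'cpu') was requested.
    return requested
```

Passing the availability flag explicitly keeps the sketch free of heavy dependencies and makes both branches easy to exercise.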
Source code in aimet_ml/features/textual/transformers.py
extract_features(texts)
Extracts features from input texts using transformer embeddings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
texts | str or list | Input text or list of texts for feature extraction. | required |
Returns:
Type | Description |
---|---|
Tensor | torch.Tensor: Extracted features for input texts. |
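This page does not say how the `num_emb_layers` layers are combined into one feature vector. One common approach, shown here purely as an assumption, is to average the last `num_emb_layers` hidden layers and then mean-pool over the non-padding tokens. The sketch uses plain Python lists instead of tensors so it runs without PyTorch; `pool_last_layers` is a hypothetical name.

```python
from typing import List

def pool_last_layers(
    hidden_states: List[List[List[float]]],  # indexed as [layer][token][dim]
    attention_mask: List[int],               # 1 for real tokens, 0 for padding
    num_emb_layers: int = 4,
) -> List[float]:
    """Average the last `num_emb_layers` layers, then mean-pool over unmasked tokens.

    Illustrative only: the real extractor may combine layers differently
    (for example by concatenation or by using a single token's vector).
    """
    if num_emb_layers < 1:
        raise ValueError("num_emb_layers must be at least 1")
    kept = sum(attention_mask)
    if kept == 0:
        raise ValueError("attention_mask selects no tokens")

    selected = hidden_states[-num_emb_layers:]  # last N layers (or all, if fewer exist)
    dim = len(selected[0][0])
    feature = [0.0] * dim
    for t, keep in enumerate(attention_mask):
        if not keep:
            continue  # padding tokens do not contribute
        for d in range(dim):
            # Mean across the selected layers for this token and dimension.
            layer_mean = sum(layer[t][d] for layer in selected) / len(selected)
            # Accumulate the mean across unmasked tokens.
            feature[d] += layer_mean / kept
    return feature
```

Masking before pooling keeps padding positions from diluting the result, which matters when texts of different lengths are processed together.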
Source code in aimet_ml/features/textual/transformers.py
tokenize(text)
Tokenizes input text using the transformer's tokenizer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
text |
str
|
Input text to be tokenized. |
required |
Returns:
Name | Type | Description |
---|---|---|
dict | dict | Dictionary containing tokenized input with attention mask. |
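To make the shape of the returned dictionary concrete, here is a toy stand-in for a transformer tokenizer. It is a sketch under several assumptions: the key names `input_ids` and `attention_mask` follow the usual Hugging Face convention, padding to `max_length` is assumed rather than documented here, and `toy_tokenize`, `vocab`, `pad_id`, and `unk_id` are hypothetical.

```python
from typing import Dict, List

def toy_tokenize(
    text: str,
    vocab: Dict[str, int],
    max_length: int = 512,
    pad_id: int = 0,
    unk_id: int = 1,
) -> Dict[str, List[int]]:
    """Whitespace-tokenize `text`, truncate to `max_length`, and pad with a mask.

    Toy illustration only; a real transformer tokenizer uses subword units
    and adds special tokens.
    """
    # Map each lowercased word to an id, using unk_id for unknown words,
    # then truncate to at most max_length tokens.
    ids = [vocab.get(word, unk_id) for word in text.lower().split()][:max_length]
    padding = max_length - len(ids)
    return {
        "input_ids": ids + [pad_id] * padding,
        # 1 marks real tokens, 0 marks padding positions.
        "attention_mask": [1] * len(ids) + [0] * padding,
    }
```

The attention mask lets downstream code tell real tokens from padding, which is why it travels alongside the token ids.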
Source code in aimet_ml/features/textual/transformers.py