class dynn.data.batching.parallel_sequences_batching.SequencePairsBatches(src_data, tgt_data, src_dictionary, tgt_dictionary=None, labels=None, max_samples=32, max_tokens=99999999, strict_token_limit=False, shuffle=True, group_by_length='source', src_left_aligned=True, tgt_left_aligned=True)¶

Bases: object
Wraps two lists of sequences as a batch iterator.
This is useful for sequence-to-sequence problems or sentence pair classification (entailment, paraphrase detection…). Following seq2seq conventions, the first sequence is referred to as the “source” and the second as the “target”.
You can then iterate over this object and get tuples of src_batch, tgt_batch ready for use in your computation graph.

Example:

    import numpy as np

    import dynn
    from dynn.data.batching.parallel_sequences_batching import SequencePairsBatches

    # Dictionary
    dic = dynn.data.dictionary.Dictionary(symbols="abcde".split())
    # 1000 source sequences of various lengths up to 10
    src_data = [np.random.randint(len(dic), size=np.random.randint(10))
                for _ in range(1000)]
    # 1000 target sequences of various lengths up to 10
    tgt_data = [np.random.randint(len(dic), size=np.random.randint(10))
                for _ in range(1000)]
    # Iterator with at most 20 samples or 50 tokens per batch
    batched_dataset = SequencePairsBatches(
        src_data, tgt_data, dic, max_samples=20, max_tokens=50
    )
    # Training loop
    for x, y in batched_dataset:
        # x and y are SequenceBatch objects
        pass
Parameters:

- src_data (list) – List of source sequences (list of int iterables)
- tgt_data (list) – List of target sequences (list of int iterables)
- src_dictionary (Dictionary) – Source dictionary
- tgt_dictionary (Dictionary) – Target dictionary
- max_samples (int, optional) – Maximum number of samples per batch (one sample is a pair of sentences)
- max_tokens (int, optional) – Maximum number of total tokens per batch (source + target tokens)
- strict_token_limit (bool, optional) – Padding tokens will count towards the max_tokens limit
- shuffle (bool, optional) – Shuffle the dataset whenever starting a new iteration (default: True)
- group_by_length (str, optional) – Group sequences by length. One of "source" or "target". This minimizes the number of padding tokens, though the batches are then not strictly IID. See the sketch after this list.
- src_left_aligned (bool, optional) – Align the source sequences to the left
- tgt_left_aligned (bool, optional) – Align the target sequences to the left
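Below is a minimal sketch of how these options combine, reusing the toy dic, src_data and tgt_data from the class example above (the variable name token_batches is just for illustration):

    from dynn.data.batching.parallel_sequences_batching import SequencePairsBatches

    # Cap each batch at 50 source + target tokens (padding included,
    # because of strict_token_limit) and bucket sequences by target
    # length to minimize padding
    token_batches = SequencePairsBatches(
        src_data, tgt_data, dic,
        max_tokens=50,
        strict_token_limit=True,
        group_by_length="target",
    )
    for src_batch, tgt_batch in token_batches:
        pass  # each is a SequenceBatch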
- __getitem__(index)¶

  Returns the index-th sample.

  The result is a tuple src_batch, tgt_batch where each element is a SequenceBatch object.

  Parameters: index (int, slice) – Index or slice
  Returns: src_batch, tgt_batch
  Return type: tuple
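For instance, reusing batched_dataset from the class example:

    # Fetch the first batch by index; each element is a SequenceBatch
    src_batch, tgt_batch = batched_dataset[0]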
- __init__(src_data, tgt_data, src_dictionary, tgt_dictionary=None, labels=None, max_samples=32, max_tokens=99999999, strict_token_limit=False, shuffle=True, group_by_length='source', src_left_aligned=True, tgt_left_aligned=True)¶

  Initialize self. See help(type(self)) for accurate signature.
- __len__()¶

  Returns the number of batches in the dataset (not the total number of samples).

  Returns: Number of batches in the dataset, ceil(len(data) / batch_size)
  Return type: int
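For example, with the 1000 sentence pairs and max_samples=20 from the class example, and assuming the token limit never binds, this would be ceil(1000 / 20) = 50 batches; a tight max_tokens can force smaller batches and therefore a larger count.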
- __weakref__¶

  list of weak references to the object (if defined)
- just_passed_multiple(batch_number)¶

  Checks whether the current number of batches processed has just passed a multiple of batch_number.

  For example, you can use this to report at regular intervals (e.g. every 10 batches).

  Parameters: batch_number (int) – The multiple to check against
  Returns: True if the number of batches processed so far has just passed a multiple of batch_number
  Return type: bool
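A hypothetical reporting snippet, reusing batched_dataset from above (and assuming percentage_done() returns a number):

    for x, y in batched_dataset:
        # ... forward/backward pass on x, y ...
        if batched_dataset.just_passed_multiple(10):
            # Report every 10 batches
            print(f"{batched_dataset.percentage_done()}% of the epoch done")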
- percentage_done()¶

  What percent of the data has been covered in the current epoch.
- reset()¶

  Reset the iterator and shuffle the dataset if applicable.
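A sketch of a multi-epoch training loop; calling reset() explicitly between epochs is shown for clarity (with shuffle=True a new iteration also reshuffles the data):

    for epoch in range(3):
        for x, y in batched_dataset:
            pass  # training step on x, y
        # Start the next epoch from a clean, reshuffled state
        batched_dataset.reset()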