Class FixedBucketSampler

java.lang.Object
ai.djl.basicdataset.utils.FixedBucketSampler
All Implemented Interfaces:
ai.djl.training.dataset.Sampler

public class FixedBucketSampler extends Object implements ai.djl.training.dataset.Sampler
FixedBucketSampler is a Sampler intended for use with TextDataset and PaddingStackBatchifier. It groups text data of the same length and samples them together, so that the amount of padding required is minimised. It also makes sure that the sampling is random across epochs.
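The bucketing idea described above can be illustrated with a small, self-contained sketch. This is not DJL's implementation: the class BucketSamplerSketch, the method bucketBatches, the equal-width bucket boundaries, and the seed parameter are all assumptions made for illustration. It only shows how grouping sample indices by length before batching keeps samples of similar length together, which is what reduces padding.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical illustration of length-based bucketing; not part of DJL. */
public class BucketSamplerSketch {

    /**
     * Groups sample indices into batches so that each batch only contains
     * samples whose lengths fall into the same bucket.
     *
     * Assumption: buckets are equal-width length intervals between the
     * shortest and longest sample. The real sampler may choose its bucket
     * boundaries differently.
     */
    public static List<List<Long>> bucketBatches(
            int[] lengths, int batchSize, int numBuckets, boolean shuffle, long seed) {
        List<List<Long>> batches = new ArrayList<>();
        if (lengths.length == 0) {
            return batches;
        }

        int min = Integer.MAX_VALUE;
        int max = Integer.MIN_VALUE;
        for (int len : lengths) {
            min = Math.min(min, len);
            max = Math.max(max, len);
        }
        // Width of each bucket's length interval (at least 1).
        int width = Math.max(1, (int) Math.ceil((max - min + 1) / (double) numBuckets));

        // Assign every sample index to the bucket that covers its length.
        List<List<Long>> buckets = new ArrayList<>();
        for (int i = 0; i < numBuckets; i++) {
            buckets.add(new ArrayList<>());
        }
        for (int i = 0; i < lengths.length; i++) {
            int b = Math.min(numBuckets - 1, (lengths[i] - min) / width);
            buckets.get(b).add((long) i);
        }

        // Cut each bucket into batches; optionally shuffle within buckets
        // and then shuffle the order of the resulting batches.
        Random rnd = new Random(seed);
        for (List<Long> bucket : buckets) {
            if (shuffle) {
                Collections.shuffle(bucket, rnd);
            }
            for (int start = 0; start < bucket.size(); start += batchSize) {
                int end = Math.min(start + batchSize, bucket.size());
                batches.add(new ArrayList<>(bucket.subList(start, end)));
            }
        }
        if (shuffle) {
            Collections.shuffle(batches, rnd);
        }
        return batches;
    }

    public static void main(String[] args) {
        // Token counts of eight hypothetical sentences.
        int[] lengths = {3, 12, 4, 11, 5, 13, 3, 12};
        for (List<Long> batch : bucketBatches(lengths, 2, 2, false, 42L)) {
            System.out.println(batch);
        }
    }
}
```

With shuffling disabled, the short sentences (indices 0, 2, 4, 6) and the long ones (indices 1, 3, 5, 7) end up in separate batches, so no short sentence has to be padded to the length of a long one. The returned lists of Long indices mirror the shape of what sample yields through its Iterator<List<Long>>.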
  • Constructor Details

    • FixedBucketSampler

      public FixedBucketSampler(int batchSize, int numBuckets, boolean shuffle)
      Constructs a new instance of FixedBucketSampler with the given batch size, number of buckets, and shuffle setting.
      Parameters:
      batchSize - the batch size
      numBuckets - the number of buckets
      shuffle - whether to shuffle data randomly while sampling
    • FixedBucketSampler

      public FixedBucketSampler(int batchSize, int numBuckets)
      Constructs a new instance of FixedBucketSampler with the given batch size and number of buckets, using the default shuffle setting.
      Parameters:
      batchSize - the batch size
      numBuckets - the number of buckets
    • FixedBucketSampler

      public FixedBucketSampler(int batchSize)
      Constructs a new instance of FixedBucketSampler with the given batch size, using the default number of buckets and the default shuffle setting.
      Parameters:
      batchSize - the batch size
  • Method Details

    • sample

      public Iterator<List<Long>> sample(ai.djl.training.dataset.RandomAccessDataset dataset)
      Specified by:
      sample in interface ai.djl.training.dataset.Sampler
    • getBatchSize

      public int getBatchSize()
      Specified by:
      getBatchSize in interface ai.djl.training.dataset.Sampler