Class StanfordQuestionAnsweringDataset

java.lang.Object
ai.djl.training.dataset.RandomAccessDataset
ai.djl.basicdataset.nlp.TextDataset
ai.djl.basicdataset.nlp.StanfordQuestionAnsweringDataset
All Implemented Interfaces:
ai.djl.training.dataset.Dataset, ai.djl.training.dataset.RawDataset<Object>

public class StanfordQuestionAnsweringDataset extends TextDataset implements ai.djl.training.dataset.RawDataset<Object>
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.
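Every answerable question's answer is a span of its passage, so an answer can be located by a character offset into the context. The sketch below is illustrative only. It assumes the answer_start offset and text fields used in the public SQuAD JSON files. It marks an unanswerable question with a negative offset, which is a convention of this sketch only. The class SquadSpan is hypothetical and is not part of DJL.

```java
import java.util.Optional;

/** Hypothetical helper: maps a SQuAD-style answer back to a span of its context. */
public class SquadSpan {

    /**
     * Returns the answer span from the context, or empty when the question is
     * unanswerable (negative start, a convention of this sketch) or the offset
     * does not line up with the answer text.
     */
    public static Optional<String> extractSpan(String context, int answerStart, String answerText) {
        if (answerStart < 0) {
            return Optional.empty();
        }
        int end = answerStart + answerText.length();
        if (end > context.length()) {
            return Optional.empty();
        }
        String span = context.substring(answerStart, end);
        // Only accept the span if the offset really points at the answer text.
        return span.equals(answerText) ? Optional.of(span) : Optional.empty();
    }

    public static void main(String[] args) {
        String context = "The Eiffel Tower is located in Paris, France.";
        System.out.println(extractSpan(context, 31, "Paris")); // Optional[Paris]
        System.out.println(extractSpan(context, -1, ""));      // Optional.empty
    }
}
```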
See Also:
  • Constructor Details

  • Method Details

    • builder

      Creates a new builder to build a StanfordQuestionAnsweringDataset.
      Returns:
      a new builder
    • prepare

      public void prepare(ai.djl.util.Progress progress) throws IOException, ai.djl.modality.nlp.embedding.EmbeddingException
      Prepares the dataset for use with tracked progress. This method parses the JSON file. It adds the question, context, and title to sourceTextData and the answers to targetTextData, and then preprocesses both lists.
      Specified by:
      prepare in interface ai.djl.training.dataset.Dataset
      Parameters:
      progress - the progress tracker
      Throws:
      IOException - for various exceptions depending on the dataset
      ai.djl.modality.nlp.embedding.EmbeddingException - if there are exceptions during the embedding process
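The flattening step that prepare describes can be sketched with simplified in-memory stand-ins for the parsed JSON. The types Article, Paragraph, and Qa and the flatten method are hypothetical, and DJL's actual parsing and bookkeeping may differ. In this sketch each title and context is added once per article or paragraph rather than once per question. That is one way the number of questions can end up different from the length of sourceTextData.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: flattening parsed SQuAD articles into source/target text lists. */
public class SquadFlatten {

    /** Minimal stand-ins for the parsed JSON objects (names are assumptions). */
    static class Qa {
        final String question;
        final List<String> answers;
        Qa(String question, List<String> answers) { this.question = question; this.answers = answers; }
    }

    static class Paragraph {
        final String context;
        final List<Qa> qas;
        Paragraph(String context, List<Qa> qas) { this.context = context; this.qas = qas; }
    }

    static class Article {
        final String title;
        final List<Paragraph> paragraphs;
        Article(String title, List<Paragraph> paragraphs) { this.title = title; this.paragraphs = paragraphs; }
    }

    final List<String> sourceTextData = new ArrayList<>();
    final List<String> targetTextData = new ArrayList<>();

    /** Adds each title and context once, every question, and every answer text. */
    void flatten(List<Article> articles) {
        for (Article article : articles) {
            sourceTextData.add(article.title);
            for (Paragraph paragraph : article.paragraphs) {
                sourceTextData.add(paragraph.context);
                for (Qa qa : paragraph.qas) {
                    sourceTextData.add(qa.question);
                    // An unanswerable question contributes no target entries.
                    targetTextData.addAll(qa.answers);
                }
            }
        }
    }

    public static void main(String[] args) {
        Article article = new Article("Paris", List.of(
                new Paragraph("Paris is the capital of France.", List.of(
                        new Qa("What is the capital of France?", List.of("Paris")),
                        new Qa("Which country is Paris in?", List.of("France"))))));
        SquadFlatten flat = new SquadFlatten();
        flat.flatten(List.of(article));
        // 1 title + 1 context + 2 questions = 4 source entries for only 2 questions.
        System.out.println(flat.sourceTextData.size() + " source, " + flat.targetTextData.size() + " target");
    }
}
```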
    • get

      public ai.djl.training.dataset.Record get(ai.djl.ndarray.NDManager manager, long index)
      Gets the Record for the given index from the dataset.
      Specified by:
      get in class ai.djl.training.dataset.RandomAccessDataset
      Parameters:
      manager - the manager used to create the arrays
      index - the index of the requested data item
      Returns:
      a Record that contains the data and label of the requested data item. The data NDList contains three NDArrays, named accordingly, that represent the embedded title, context, and question. The label NDList contains one NDArray for each embedded answer.
    • availableSize

      protected long availableSize()
      Returns the number of records available to be read in this Dataset. In this implementation, the number of available records is the size of questionInfoList.
      Specified by:
      availableSize in class ai.djl.training.dataset.RandomAccessDataset
      Returns:
      the number of records available to be read in this Dataset
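The get and availableSize descriptions fit together if each question stores indices into the shared text lists. A record is then assembled by lookup, and the record count is the number of questions. This is an illustrative assumption. The documentation names questionInfoList but not its fields, so QuestionInfo and its index fields are hypothetical. Plain strings stand in for the embedded NDArrays that the real get returns.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical index-based lookup; strings stand in for the embedded NDArrays. */
public class SquadRecordLookup {

    /** Where one question's pieces live in the shared text lists (assumed fields). */
    static class QuestionInfo {
        final int titleIndex;
        final int contextIndex;
        final int questionIndex;
        final List<Integer> answerIndices;

        QuestionInfo(int titleIndex, int contextIndex, int questionIndex, List<Integer> answerIndices) {
            this.titleIndex = titleIndex;
            this.contextIndex = contextIndex;
            this.questionIndex = questionIndex;
            this.answerIndices = answerIndices;
        }
    }

    final List<String> sourceTextData;
    final List<String> targetTextData;
    final List<QuestionInfo> questionInfoList;

    SquadRecordLookup(List<String> source, List<String> target, List<QuestionInfo> infos) {
        this.sourceTextData = source;
        this.targetTextData = target;
        this.questionInfoList = infos;
    }

    /** Data part of a record: title, context and question, in that order. */
    List<String> dataAt(long index) {
        QuestionInfo info = questionInfoList.get(Math.toIntExact(index));
        return List.of(
                sourceTextData.get(info.titleIndex),
                sourceTextData.get(info.contextIndex),
                sourceTextData.get(info.questionIndex));
    }

    /** Label part of a record: one entry per answer (empty if unanswerable). */
    List<String> labelAt(long index) {
        QuestionInfo info = questionInfoList.get(Math.toIntExact(index));
        List<String> answers = new ArrayList<>();
        for (int answerIndex : info.answerIndices) {
            answers.add(targetTextData.get(answerIndex));
        }
        return answers;
    }

    /** One record per question, as availableSize() describes. */
    long availableSize() {
        return questionInfoList.size();
    }

    public static void main(String[] args) {
        List<String> source = List.of("Paris", "Paris is the capital of France.",
                "What is the capital of France?", "Which country is Paris in?");
        List<String> target = List.of("Paris", "France");
        SquadRecordLookup lookup = new SquadRecordLookup(source, target, List.of(
                new QuestionInfo(0, 1, 2, List.of(0)),
                new QuestionInfo(0, 1, 3, List.of(1))));
        System.out.println(lookup.availableSize() + " records from " + source.size() + " source entries");
        System.out.println(lookup.dataAt(1) + " -> " + lookup.labelAt(1));
    }
}
```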
    • getData

      public Object getData() throws IOException
      Gets data from the SQuAD dataset. This method returns the whole dataset directly as a single object.
      Specified by:
      getData in interface ai.djl.training.dataset.RawDataset<Object>
      Returns:
      an Object whose structure mirrors the parsed JSON, e.g. Map<String, List<Map<...>>>
      Throws:
      IOException
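getData returns untyped nested maps and lists, so a caller has to cast at each level to reach the questions. The sketch below assumes the field names of the public SQuAD JSON files: data, title, paragraphs, context, qas, question, and answers. SquadRawTraversal and its hand-built sample tree are hypothetical. The sample stands in for what getData returns and is not real dataset output.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical walk over the Map/List tree that getData() is documented to return. */
public class SquadRawTraversal {

    /** Counts questions by following data -> paragraphs -> qas. */
    @SuppressWarnings("unchecked")
    static int countQuestions(Object raw) {
        Map<String, Object> root = (Map<String, Object>) raw;
        int count = 0;
        for (Object article : (List<Object>) root.get("data")) {
            Map<String, Object> articleMap = (Map<String, Object>) article;
            for (Object paragraph : (List<Object>) articleMap.get("paragraphs")) {
                Map<String, Object> paragraphMap = (Map<String, Object>) paragraph;
                count += ((List<Object>) paragraphMap.get("qas")).size();
            }
        }
        return count;
    }

    /** A tiny hand-built tree mimicking the SQuAD JSON layout. */
    static Object sample() {
        Map<String, Object> qa1 = Map.of(
                "question", "What is the capital of France?",
                "answers", List.of(Map.<String, Object>of("text", "Paris", "answer_start", 0)));
        Map<String, Object> qa2 = Map.of(
                "question", "Which country is Paris in?",
                "answers", List.of(Map.<String, Object>of("text", "France", "answer_start", 24)));
        Map<String, Object> paragraph = Map.of(
                "context", "Paris is the capital of France.",
                "qas", List.of(qa1, qa2));
        Map<String, Object> article = Map.of("title", "Paris", "paragraphs", List.of(paragraph));
        return Map.of("data", List.of(article));
    }

    public static void main(String[] args) {
        System.out.println(countQuestions(sample())); // 2
    }
}
```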
    • preprocess

      protected void preprocess(List<String> newTextData, boolean source) throws ai.djl.modality.nlp.embedding.EmbeddingException
      Performs preprocessing steps on text data, such as tokenising, applying TextProcessors, creating the vocabulary, and computing word embeddings. Because the number of records in this dataset is not equal to the length of sourceTextData and targetTextData, the limit has to be handled specially.
      Overrides:
      preprocess in class TextDataset
      Parameters:
      newTextData - list of all unprocessed sentences in the dataset
      source - whether the text data provided is source or target
      Throws:
      ai.djl.modality.nlp.embedding.EmbeddingException - if there is an error while embedding input
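The tokenising and vocabulary steps named under preprocess can be pictured with a deliberately simplified stand-in. This sketch is not DJL's implementation. It uses a lower-casing regex tokenizer in place of TextProcessors and reserves index 0 for unknown tokens, a convention of this sketch only. It leaves out word embeddings entirely. SimpleVocabulary is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/** Simplified stand-in for tokenize -> vocabulary; not DJL's implementation. */
public class SimpleVocabulary {

    static final String UNKNOWN = "<unk>";

    final Map<String, Integer> tokenToIndex = new LinkedHashMap<>();

    SimpleVocabulary() {
        tokenToIndex.put(UNKNOWN, 0); // index 0 is reserved for out-of-vocabulary tokens
    }

    /** Lower-cases the sentence and splits on runs of non-alphanumeric (ASCII) characters. */
    static List<String> tokenize(String sentence) {
        List<String> tokens = new ArrayList<>();
        for (String token : sentence.toLowerCase(Locale.ROOT).split("[^\\p{Alnum}]+")) {
            if (!token.isEmpty()) {
                tokens.add(token);
            }
        }
        return tokens;
    }

    /** Adds every token of every sentence, assigning indices in first-seen order. */
    void build(List<String> sentences) {
        for (String sentence : sentences) {
            for (String token : tokenize(sentence)) {
                tokenToIndex.putIfAbsent(token, tokenToIndex.size());
            }
        }
    }

    /** Maps a sentence to vocabulary indices, using 0 for unknown tokens. */
    List<Integer> encode(String sentence) {
        List<Integer> ids = new ArrayList<>();
        for (String token : tokenize(sentence)) {
            ids.add(tokenToIndex.getOrDefault(token, 0));
        }
        return ids;
    }

    public static void main(String[] args) {
        SimpleVocabulary vocab = new SimpleVocabulary();
        vocab.build(List.of("What is the capital of France?"));
        System.out.println(vocab.encode("The capital of Spain")); // [3, 4, 5, 0]
    }
}
```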