This repository was archived by the owner on Nov 8, 2022. It is now read-only.

improvement: distillation for TransformerSequenceClassifier models for GLUE tasks #218


@rmovva

Hi,

I'm wondering if it would be easy to add support for knowledge distillation for the Transformers on GLUE tasks (i.e. the TransformerSequenceClassifier module).

I see that the distillation loss has been implemented, and it's an option for the NeuralTagger which uses the TransformerTokenClassifier. Would it be easy to add distillation support for the GLUE models?

Here's how I was envisioning implementing it, modeling off of the distillation implementation for the tagger models:

  1. TransformerSequenceClassifier's train() calls the base transformer model's _train(). I would add a distiller argument to _train(), which would then handle distillation by loading in the teacher and the relevant arguments, just as the NeuralTagger does.
  2. In procedures/transformers/glue.py, I would add a do_kd_training function that adds the distillation args. This function would create a teacher model from these args (loading weights from the passed-in path), create a TeacherStudentDistill instance, and pass that object as the distiller argument to the new _train().
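
To make the two steps above concrete, here is a rough, framework-free sketch of how an optional distiller argument could change the per-example loss for a sequence classifier. Everything in it is an assumption for illustration: `SketchDistiller`, `training_loss`, `teacher_fn`, `temperature` and `kd_weight` are hypothetical names, not the library's `TeacherStudentDistill` or `_train()` signatures, and the real implementation would operate on batched tensors rather than Python lists.

```python
import math
from typing import Callable, List, Optional, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Standard hard-label loss for a single example."""
    return -math.log(softmax(logits)[label])


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


class SketchDistiller:
    """Hypothetical stand-in for a distiller object (not the library's class)."""

    def __init__(self, teacher_fn: Callable[[object], Sequence[float]],
                 temperature: float = 2.0, kd_weight: float = 0.5) -> None:
        self.teacher_fn = teacher_fn      # maps an example to teacher logits
        self.temperature = temperature
        self.kd_weight = kd_weight

    def distill_loss(self, task_loss: float,
                     student_logits: Sequence[float], example: object) -> float:
        t = self.temperature
        p_teacher = softmax(self.teacher_fn(example), t)
        p_student = softmax(student_logits, t)
        # The T^2 factor keeps the soft-target term on a scale comparable
        # to the hard-label term as the temperature changes.
        kd = kl_divergence(p_teacher, p_student) * t * t
        return (1.0 - self.kd_weight) * task_loss + self.kd_weight * kd


def training_loss(student_logits: Sequence[float], label: int, example: object,
                  distiller: Optional[SketchDistiller] = None) -> float:
    """Loss for one example; falls back to plain cross-entropy without a distiller."""
    loss = cross_entropy(student_logits, label)
    if distiller is not None:
        loss = distiller.distill_loss(loss, student_logits, example)
    return loss
```

With `distiller=None` the training path is unchanged, which is the main reason for threading it through as an optional argument: existing GLUE runs would keep their current behavior, and only the new KD procedure would pass a distiller in.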

Does this seem about right? Do you foresee any roadblocks, and is there a reason distillation wasn't implemented for the sequence classifier models to begin with?
