
Possibility of optimisation in Focal Loss implementation #1235

@vedantdalimkar


I was training a multiclass model with the FocalLoss provided by SMP, and it turned out to be a training bottleneck: replacing torch's CrossEntropyLoss with SMP's FocalLoss made training about 6x slower.

I think this may be caused by the suboptimal multiclass-mode implementation, in which the loss iterates over the classes and calls focal_loss_with_logits separately for each one. This could be vectorised by calling focal_loss_with_logits once for all classes, with the one-hot encoding of the segmentation mask as the target.
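A minimal sketch of the idea, assuming each per-class call computes an element-wise binary focal loss on logits and reduces it with a mean. `binary_focal_with_logits`, `focal_loss_per_class_loop` and `focal_loss_vectorised` are hypothetical helpers written for illustration; they are not SMP's code or API. The sketch also ignores ignored-pixel handling.

```python
import torch
import torch.nn.functional as F


def binary_focal_with_logits(logits, target, gamma=2.0, alpha=0.25):
    """Element-wise binary focal loss on logits, no reduction (hypothetical helper)."""
    target = target.to(logits.dtype)
    # Per-element BCE with logits equals -log(p_t)
    logpt = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    pt = torch.exp(-logpt)
    loss = (1.0 - pt).pow(gamma) * logpt
    if alpha is not None:
        loss = loss * (alpha * target + (1 - alpha) * (1 - target))
    return loss


def focal_loss_per_class_loop(logits, mask, gamma=2.0, alpha=0.25):
    """One binary focal loss per class, each mean-reduced, then summed (the loop)."""
    num_classes = logits.size(1)
    total = logits.new_zeros(())
    for cls in range(num_classes):
        cls_target = (mask == cls).long()
        cls_logits = logits[:, cls, ...]
        total = total + binary_focal_with_logits(cls_logits, cls_target, gamma, alpha).mean()
    return total


def focal_loss_vectorised(logits, mask, gamma=2.0, alpha=0.25):
    """Same quantity from a single call, using a one-hot target."""
    num_classes = logits.size(1)
    # (N, *spatial) -> (N, *spatial, C) -> (N, C, *spatial) to match the logits layout
    one_hot = F.one_hot(mask, num_classes).movedim(-1, 1)
    loss = binary_focal_with_logits(logits, one_hot, gamma, alpha)
    # Mean over everything except the class dim, then sum over classes,
    # which reproduces "sum of per-class means" from the loop
    return loss.transpose(0, 1).reshape(num_classes, -1).mean(dim=1).sum()


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 8, 8)       # (N, C, H, W)
    mask = torch.randint(0, 4, (2, 8, 8))  # (N, H, W) class indices
    print(focal_loss_per_class_loop(logits, mask).item())
    print(focal_loss_vectorised(logits, mask).item())
```

The two values should agree up to floating-point rounding. The reduction is the detail to get right: if every per-class call takes a mean, the loop returns a sum of C means, so a single call that takes one global mean would be C times smaller unless it reduces per class first (or multiplies by C).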
