Fixes ctc loss input checking#50

Open
gwenniger wants to merge 4 commits into SeanNaren:pytorch_bindings from gwenniger:fixes_ctc_loss_input_checking

Conversation

@gwenniger

Dear Sean,
I started using your warp_ctc interface code and I enjoy it a lot. It took me some time, however, to find out how exactly to use the interface. Eventually I was able to reproduce the examples from https://github.com/baidu-research/warp-ctc/blob/master/torch_binding/TUTORIAL.md
I therefore thought it good to add some more explanatory checks on the input of the CTCLoss forward function, which, as far as I understand, is the main entry point. I added the following checks:

Added additional and clearer checks on the correctness of the input
to the CTCLoss forward function:

1. Check that the labels input has dimensionality 1. There was already
   an assertion for this, but I have replaced it with an explanatory
   runtime error that explains that a 1-dimensional label tensor is
   expected, containing the label sequences for all examples
   concatenated.

2. Check that the sum of the label lengths in label_lengths is equal to
   the length of the labels tensor. This must be the case, since
   label_lengths essentially specifies how labels is to be segmented to
   retrieve the original label sequences for the different examples,
   which were concatenated for the sake of computation.
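The two checks above can be sketched as follows, using plain Python lists in place of tensors; the function name and the exact error messages are illustrative, not the code from this PR:

```python
def check_ctc_labels(labels, label_lengths):
    """Validate the concatenated-labels convention for CTCLoss inputs."""
    # Check 1: labels must be one-dimensional; a nested sequence means
    # the caller passed per-example sequences instead of concatenating.
    if any(isinstance(x, (list, tuple)) for x in labels):
        raise RuntimeError(
            "labels must be a 1-dimensional tensor containing the label "
            "sequences for all examples concatenated")
    # Check 2: the per-example lengths must account for every label,
    # since label_lengths determines how labels is segmented.
    if sum(label_lengths) != len(labels):
        raise RuntimeError(
            "sum(label_lengths) (%d) must equal the length of labels (%d)"
            % (sum(label_lengths), len(labels)))

# Valid: two label sequences [1, 2, 3] and [4, 5], concatenated.
check_ctc_labels([1, 2, 3, 4, 5], [3, 2])  # passes silently
```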

I think these checks add to the ease of use of the function and will give the user instructions that help overcome the problem when the interface is not used correctly. Right now, in certain cases, for example when the second check is not satisfied, the program still runs, but the output is not properly interpretable.
See https://discuss.pytorch.org/t/ctcloss-dont-work-in-pytorch/13859
"
Can you post a minimal gist or so to reproduce?
(I.e. precompute outputs and target and just have your ctc application.)
It works for me but acts funny on invalid inputs etc.

Best regards Thomas
"
I think checking the input for non-valid combinations of tensors will be a good way to improve the discoverability of the correct use of the code, and solve some of the problems with wrong output for invalid inputs.
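To illustrate the convention that these checks enforce (my own example, not code from the PR or from warp-ctc): per-example label sequences are flattened into one 1-D list, and label_lengths records how to split it back.

```python
# Two examples with label sequences [1, 2, 3] and [4, 5].
label_sequences = [[1, 2, 3], [4, 5]]

labels = [l for seq in label_sequences for l in seq]   # [1, 2, 3, 4, 5]
label_lengths = [len(seq) for seq in label_sequences]  # [3, 2]

# label_lengths is exactly the information needed to segment `labels`
# back into the original per-example sequences.
recovered, start = [], 0
for n in label_lengths:
    recovered.append(labels[start:start + n])
    start += n

assert recovered == label_sequences
assert sum(label_lengths) == len(labels)
```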

gwenniger added 4 commits May 15, 2018 21:12
in  pytorch_binding/warpctc_pytorch/__init__.py.
This method was using the trick
"ctx.grads = Variable(grads, volatile=True)"
to avoid computation of gradients, but this is now deprecated.
The fix, for pytorch 0.4.0, is to use "with torch.no_grad():"
around the things that should not compute gradient information for
autograd instead.

See also:
https://discuss.pytorch.org/t/torch-no-grad/12296
https://pytorch.org/2018/04/22/0_4_0-migration-guide.html

	modified:   pytorch_binding/warpctc_pytorch/__init__.py
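A minimal sketch of the migration this commit describes (my own example, assuming PyTorch >= 0.4.0; it does not reproduce the actual warp_ctc backward code):

```python
import torch

x = torch.ones(3, requires_grad=True)

# Old, deprecated style: wrap the tensor in a volatile Variable,
# e.g. ctx.grads = Variable(grads, volatile=True)

# New style since PyTorch 0.4.0: suppress autograd tracking with a
# context manager around the operations that need no gradients.
with torch.no_grad():
    y = x * 2

print(y.requires_grad)  # False: no graph was recorded inside no_grad()
```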
Added additional and clearer checks on the correctness of the input
to the CTCLoss forward function (the two checks described above).

	modified:   pytorch_binding/warpctc_pytorch/__init__.py
Removed previously added debugging output.

	modified:   pytorch_binding/warpctc_pytorch/__init__.py
pytorch 0.4.0, the version that we need.

	modified:   pytorch_binding/warpctc_pytorch/__init__.py