Conversation
Thanks for opening your first pull request in this repository! Someone will review it when they have a chance. In the meantime, please be sure that you've handled the following things, to make the review process quicker and easier:
Thank you again for your contributions! 👍
Hey @mrityunjay-tripathi, thanks for opening this PR. If possible, could you create a models folder and place the transformer inside it? Since many models will be added to the repo, it would be great to have a models folder rather than a lot of models in the main folder. So it would be models/models/transformer.
Sure @kartikdutt18! Actually, I was also wondering why it wasn't that way. But no worries, I will make it that way now 👍
Awesome. The reason is that it's part of Restructuring - 3. All existing models will be replaced with something similar to what you have implemented, i.e. wrapped in a class, so that users can use pre-trained models.
This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍
0c08f69 to 37e2414
@lozhnikov I've made the changes you suggested. Could you please take a look? It's mostly done, and once the required layers are merged we can test this.
Sure, I'll review the PR today (in the evening).
I was trying to test this locally and I got the following error:

error: matrix multiplication: incompatible matrix dimensions: 0x0 and 16x10
unknown location(0): fatal error: in "FFNModelsTests/TransformerEncoderTest": std::logic_error: matrix multiplication: incompatible matrix dimensions: 0x0 and 16x10

I feel there is some problem with Reset. The weights and biases are not allocated memory, but I can't find why.
I'll look into it in the morning. Upd: I'd use GDB to find the actual place where it happens.
@mrityunjay-tripathi I think I get it. Looks like my comment #16 (comment) was wrong. We need to pass
Yeah. Got it. Thanks :)
Co-authored-by: Mikhail Lozhnikov <lozhnikovma@gmail.com>
5e572fa to fbdd4ff
@mrityunjay-tripathi Is this PR ready? Can I review this?
Yes. This is ready for review.
lozhnikov left a comment
Sorry for the slow response. The beginning of the term was hard; now things have settled a bit. I found a tiny flaw in the decoder implementation. I'll suggest the fix in the evening.
models/transformer/decoder.hpp
Outdated
MultiheadAttention<>* mha1 = new MultiheadAttention<>(tgtSeqLen,
                                                      tgtSeqLen,
                                                      dModel,
                                                      numHeads);
Shouldn't the second argument be equal to srcSeqLen?
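If that reading is right, the constructor call would presumably become the following (a sketch only; the argument order is assumed from the snippet quoted above):

MultiheadAttention<>* mha1 = new MultiheadAttention<>(tgtSeqLen,
                                                      srcSeqLen,  // key/value come from the encoder
                                                      dModel,
                                                      numHeads);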
// This layer concatenates the output of the bottom decoder block (query)
// and the output of the encoder (key, value).
Concat<>* encDecAttnInput = new Concat<>(true);
encDecAttnInput->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);
I think this is incorrect. This takes the decoder's bottom input, but the encoder-decoder attention block should receive the output of the decoder bottom.
- encDecAttnInput->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);
+ encDecAttnInput->Add(decoderBlockBottom);
// Residual connection.
AddMerge<>* residualAdd2 = new AddMerge<>(true);
residualAdd2->Add(encoderDecoderAttention);
residualAdd2->Add(decoderBlockBottom);
You can't pass the same block twice (see the comment on encDecAttnInput). It looks like we need to change the model a bit. I have to go now; I'll come up with an idea in the evening.
This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍
Hello everyone! I've implemented the transformer encoder and decoder. Though there are other dependencies for this PR, I opened it to get some insights and opinions. Some things are still remaining regarding this PR: