The mask in `jvp` might not match the pytree structure of `grad_y`.

Unfortunately, I don't have an open-source repro at the moment, but the following lines in `folx/folx/jvp.py` (commit d3bf210, lines 183 and 248) seem to assume that the mask and the gradients are pytrees of the same structure, whereas the mask is a single leaf and the gradients are a pytree:

- `new_masks = broadcast_mask_to_jacobian(out_mask, grad_y)`
- `result_mask = broadcast_mask_to_jacobian(result_mask, grad_y)`

Is that correct?

I believe you recently addressed a similar issue here:
5150e33#diff-21e634aa62155f577c8e87e1b851189b4791db79bdb2593cc957ca86e8cde5ccL328
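To illustrate the mismatch, here is a minimal sketch using plain `jax.tree_util` rather than folx itself. The values of `grad_y` and `out_mask`, the dict layout of the gradients, and the helper `broadcast_leaf_mask` are all hypothetical stand-ins. This is not how folx's `broadcast_mask_to_jacobian` works; it only shows one way a single-leaf mask could be lifted to the gradients' pytree structure before the two are paired leaf by leaf.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins (not taken from folx):
# grad_y is a pytree (a dict of arrays), out_mask is a single array leaf.
grad_y = {"a": jnp.ones((3, 2)), "b": jnp.ones((3, 4))}
out_mask = jnp.arange(3)


def same_structure(x, y) -> bool:
    """Return True if two pytrees have identical tree structure."""
    return jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(y)


def broadcast_leaf_mask(mask, tree):
    """Hypothetical helper: place the same single-leaf mask at every leaf of
    `tree`, so the result has exactly the pytree structure of `tree`."""
    return jax.tree_util.tree_map(lambda _: mask, tree)


# A lone leaf and a dict of arrays have different pytree structures ...
print(same_structure(out_mask, grad_y))  # False

# ... so the mask has to be lifted to grad_y's structure before any
# per-leaf combination of the two.
per_leaf_mask = broadcast_leaf_mask(out_mask, grad_y)
print(same_structure(per_leaf_mask, grad_y))  # True
```

Whether folx should lift the mask like this or handle the leaf case separately is a design choice for the maintainers; the sketch only makes the structural difference concrete.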