coord_check for a model that returns the loss directly #68
Open
Description
Some transformers (like x-transformers) take in a sequence of length `seq_len + 1`, split it into `input = x[:-1]` and `target = x[1:]`, and compute the loss directly in `forward()`. This is efficient because the input and target overlap, so a single tensor serves as both. It also means that `forward()` returns the loss rather than the model outputs (logits).
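A minimal PyTorch sketch of such a model, assuming a toy embedding-plus-linear network. The class name `LossReturningLM` and its sizes are hypothetical, and this is not x-transformers' actual implementation. `forward()` receives `seq_len + 1` tokens, shifts them into input and target, and returns a scalar loss instead of logits:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LossReturningLM(nn.Module):
    """Hypothetical toy model whose forward() returns the loss, not the logits."""

    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        # seq: (batch, seq_len + 1). Input and target share seq_len - 1 tokens.
        inp, target = seq[:, :-1], seq[:, 1:]
        logits = self.head(self.embed(inp))  # (batch, seq_len, vocab_size)
        # cross_entropy expects the class dimension second: (batch, vocab_size, seq_len)
        return F.cross_entropy(logits.transpose(1, 2), target)


torch.manual_seed(0)
model = LossReturningLM()
batch = torch.randint(0, 32, (4, 9))  # 4 sequences with seq_len = 8
loss = model(batch)                    # scalar loss straight from forward()
loss.backward()
print(loss.shape, loss.item() > 0)
```

Because the model returns a scalar, a training or coord-check step only needs to call `model(batch)` and backpropagate; no separate `lossfn` or target tensor is involved.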
It would be nice if coord_check had an option that supports this use case, where `forward()` returns the loss directly, for example by adding a `loss_from_forward` argument to the function signatures and inserting the following branch:
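```python
elif loss_from_forward:
    # The model splits the batch into input/target and computes the loss itself,
    # so the whole batch is passed in and the return value is the loss.
    if cuda:
        batch = batch.cuda()
    loss = model(batch)
```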
at line 317 (commit 1981497), immediately before the existing `else:` branch.
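A sketch of how the proposed branch might sit alongside the usual `(data, target)` path, written as a standalone helper. `compute_step_loss` and its parameters are hypothetical, not mup's API; the stand-in model and loss functions are plain Python so the dispatch can be run without torch or a GPU:

```python
def compute_step_loss(model, batch, lossfn=None, cuda=False, loss_from_forward=False):
    """Hypothetical helper returning the loss for one coord-check step.

    loss_from_forward=True:  batch is a single tensor (seq_len + 1 tokens) and
                             model(batch) already returns the loss.
    loss_from_forward=False: batch is (data, target) and lossfn is applied
                             to the model output.
    """
    if loss_from_forward:
        if cuda:
            batch = batch.cuda()
        return model(batch)
    data, target = batch
    if cuda:
        data, target = data.cuda(), target.cuda()
    output = model(data)
    return lossfn(output, target)


# Plain-Python stand-ins (hypothetical) to exercise both branches.
def loss_model(seq):
    return float(len(seq))  # pretends to compute a loss internally


def plain_model(x):
    return [v * 2 for v in x]


def mse(out, tgt):
    return sum((o - t) ** 2 for o, t in zip(out, tgt)) / len(out)


print(compute_step_loss(loss_model, [1, 2, 3], loss_from_forward=True))  # 3.0
print(compute_step_loss(plain_model, ([1, 2], [2, 5]), lossfn=mse))      # 0.5
```

Keeping the loss-from-forward case as its own branch leaves the existing `(data, target)` path unchanged for models that return outputs.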