PyTorch backpropagation after modifying outputs

My task is to answer a question given a related context that contains the answer. The data is structured like this:

    "context": "some long interesting text",
    "question": "number showing how many atoms or molecules of a given element or compound are involved in a chemical",
    "answers": {
        "answer_start": [
            1800,
            3495,
            1576,
            3359
        ],
        "text": [
            "coefficient",
            "coefficient",
            "coefficient",
            "coefficient"
        ]
    }

The model needs to predict the word “coefficient” in the end. My approach is to use a QA model from Hugging Face. This model predicts the position in the context where the answer is found, analogous to the answer_start values in the dataset.
The problem I ran into is that it predicts the position in terms of the tokens it was fed, so it returns e.g. token 300 where the dataset says character 1000.
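
For context, with a fast Hugging Face tokenizer the token-to-character mapping can be read off directly via return_offsets_mapping=True. A minimal sketch (the checkpoint name is just an example):

    from transformers import AutoTokenizer

    # fast tokenizers return one (char_start, char_end) pair per token
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    enc = tokenizer("some long interesting text", return_offsets_mapping=True)

    token_idx = 3
    char_start, char_end = enc["offset_mapping"][token_idx]  # character span of token 3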

To work around this, I convert the token position to a character position, but this conversion destroys the link to the model's output, so I can't compute gradients.
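My understanding is that torch.argmax is the step that cuts the graph: it returns a plain integer index with no grad_fn, so any tensor filled from it is detached. A small illustration of the effect:

    import torch

    logits = torch.randn(5, requires_grad=True)
    idx = torch.argmax(logits)   # integer index, not differentiable
    print(idx.requires_grad)     # False: the autograd graph ends here

    pred = torch.empty(1)
    pred[0] = idx                # fresh tensor filled with a plain number
    print(pred.grad_fn)          # None, so backward() cannot reach the model
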
This is the relevant code I use:

    import torch

    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
    model.train()
    for epoch in range(1):
        for batch in loader:
            optimizer.zero_grad()
            inputs, answers, questions = batch
            inputs = {key: value.squeeze(1).to(device) for key, value in inputs.items()}
            outputs = model(**inputs)

            start_scores = outputs.start_logits
            end_scores = outputs.end_logits

            batch_size = start_scores.size(0)
            y_true_start = torch.empty(batch_size, dtype=torch.float, device=device)
            y_true_end = torch.empty(batch_size, dtype=torch.float, device=device)
            y_pred_start = torch.empty(batch_size, dtype=torch.float, device=device)
            y_pred_end = torch.empty(batch_size, dtype=torch.float, device=device)

            for i in range(batch_size):
                # argmax picks the most likely start/end token (not differentiable)
                answer_start = torch.argmax(start_scores[i])
                answer_end = torch.argmax(end_scores[i])
                question = questions[i]

                answer = tokenizer.convert_tokens_to_string(
                    tokenizer.convert_ids_to_tokens(inputs["input_ids"][i][answer_start:answer_end + 1])
                )
                print(answer)

                # convert answer_start and answer_end from token level to char level
                answer_until_start = tokenizer.convert_tokens_to_string(
                    tokenizer.convert_ids_to_tokens(inputs["input_ids"][i][:answer_start])
                ).replace("<s>", "").replace("</s>", "")
                escaped_chars = answer_until_start.count("\"")
                escaped_chars += answer_until_start.count("\\")
                char_start = len(answer_until_start) - len(question) + escaped_chars + 1
                char_end = char_start + len(answer.strip())

                # find the closest start and end out of multiple possible answers
                start = answers["answer_start"][i]
                end = answers["answer_end"][i]
                closest_match_idx = (start - char_start).abs().argmin()
                start_true, end_true = start[closest_match_idx], end[closest_match_idx]

                # filling these tensors with plain numbers detaches them from the graph
                y_true_start[i] = start_true
                y_true_end[i] = end_true
                y_pred_start[i] = char_start
                y_pred_end[i] = char_end

            loss1 = torch.nn.MSELoss()(y_pred_start, y_true_start)
            loss2 = torch.nn.MSELoss()(y_pred_end, y_true_end)
            loss = loss1 + loss2
            loss.backward(retain_graph=True)  # fails: loss has no grad_fn
            optimizer.step()

What could I change to keep the connection to the model's original output so that I can compute gradients?
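
For reference, the standard setup I'm aware of supervises at the token level instead: if start_positions and end_positions (token indices) are passed to a Hugging Face QA model, it computes a cross-entropy loss over the logits internally, which stays differentiable. A rough sketch of that direction, assuming a fast tokenizer and placeholder variables (question, context, answer_char_start, answer_char_end), where my character-level labels would first have to be mapped to token indices:

    # map the character-level label to token indices (sequence_index=1 selects the context)
    enc = tokenizer(question, context, return_tensors="pt", truncation=True).to(device)
    start_tok = enc.char_to_token(answer_char_start, sequence_index=1)
    end_tok = enc.char_to_token(answer_char_end - 1, sequence_index=1)

    outputs = model(
        **enc,
        start_positions=torch.tensor([start_tok], device=device),
        end_positions=torch.tensor([end_tok], device=device),
    )
    outputs.loss.backward()  # cross-entropy on the logits, so gradients flow

This would train, but it replaces my character-level objective entirely, so I'd still like to know whether the character-level comparison can stay connected to the graph.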
