
Fine-Tuning a BERT Model

5 months ago 67


BERT is a foundation NLP model pre-trained to understand language, but it does not solve any specific task out of the box. You can, however, build upon BERT by adding an appropriate model head and training it for a specific task. This process is called fine-tuning. In this article, you will learn how to fine-tune a BERT model for several NLP tasks.

Let’s get started.

Overview

This article is divided into two parts; they are:

  • Fine-tuning a BERT Model for GLUE Tasks
  • Fine-tuning a BERT Model for SQuAD Tasks

Fine-tuning a BERT Model for GLUE Tasks

GLUE is a benchmark for evaluating natural language understanding (NLU). It contains 9 tasks, such as sentiment analysis, paraphrase detection, and natural language inference. For each task, the model learns task-specific behavior from labeled examples. GLUE also has a held-out test set for each task whose labels are not published; models are scored on it through a public leaderboard.

Let’s take the “sst2” task (sentiment classification) in GLUE as an example.

You can load the dataset using the Hugging Face datasets library:

from datasets import load_dataset

task = "sst2" # sentiment classification

dataset = load_dataset("glue", task)

print("Train size:", len(dataset["train"]))

print("Validation size:", len(dataset["validation"]))

print("Test size:", len(dataset["test"]))

# print one sample

print(dataset["train"][42])

Running this code, the output is:

Train size: 67349

Validation size: 872

Test size: 1821

{'sentence': "as they come , already having been recycled more times than i 'd care to count ", 'label': 0, 'idx': 42}

The loaded dataset has three splits: train, validation, and test. Each sample in the dataset is a dictionary, and the keys you need are "sentence" and "label". The label is either 0 or 1, representing negative or positive sentiment, respectively. The labels in the test split are hidden (all set to -1), so you will use the validation split to evaluate the model.

This dataset cannot be used directly because you need to convert text sentences into token sequences. Moreover, the training loop requires data in batches, so you need to create batches of shuffled and padded sequences. Let’s create a PyTorch DataLoader with a custom collate function:


...

import functools

import tokenizers

import torch

def collate(batch: list[dict], tokenizer: tokenizers.Tokenizer, max_len: int):

    """Custom collate function to handle variable-length sequences in dataset."""

    cls_id = tokenizer.token_to_id("[CLS]")

    sep_id = tokenizer.token_to_id("[SEP]")

    pad_id = tokenizer.token_to_id("[PAD]")

    sentences: list[str] = [item["sentence"] for item in batch]

    labels = torch.tensor([item["label"] for item in batch])

    input_ids = []

    for sentence in sentences:

        seq = [cls_id]

        seq.extend(tokenizer.encode(sentence).ids)

        if len(seq) >= max_len:

            seq = seq[:max_len-1]

        seq.append(sep_id)

        num_pad = max_len - len(seq)

        seq.extend([pad_id] * num_pad)

        input_ids.append(seq)

    input_ids = torch.tensor(input_ids, dtype=torch.long)

    return input_ids, labels

batch_size = 16

max_len = 128

tokenizer = tokenizers.Tokenizer.from_file("wikitext-2_wordpiece.json")

collate_fn = functools.partial(collate, tokenizer=tokenizer, max_len=max_len)

train_loader = torch.utils.data.DataLoader(dataset["train"], batch_size=batch_size,

                                           shuffle=True, collate_fn=collate_fn)

val_loader = torch.utils.data.DataLoader(dataset["validation"], batch_size=batch_size,

                                         shuffle=False, collate_fn=collate_fn)

To prepare the data, you need the tokenizer that you trained in the previous post. The foundation BERT model must have been pre-trained with this same tokenizer; otherwise, the token IDs would not match the embeddings the model learned.

The collate() function takes a batch of samples as a list of dictionaries. It converts text sentences into token sequences and pads them to the same length. Unlike BERT pre-training, you do not have a pair of sentences, but you still need to use the [CLS] and [SEP] tokens as delimiters in the output sequence. The output of the collate function is a tuple of two tensors: the input IDs as a 2D tensor and the labels as a 1D tensor.
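To see the layout the collate function produces without needing the trained tokenizer, here is a minimal sketch that applies the same wrap, truncate, and pad logic to sentences that are already encoded as token IDs. The special-token IDs (`CLS_ID`, `SEP_ID`, `PAD_ID`) and the helper `build_sequence()` are hypothetical values chosen for illustration; in the article they come from `tokenizer.token_to_id()`.

```python
import torch

# Hypothetical special-token IDs for illustration only; the real values come
# from tokenizer.token_to_id("[CLS]"), ("[SEP]"), and ("[PAD]")
CLS_ID, SEP_ID, PAD_ID = 1, 2, 0


def build_sequence(token_ids: list[int], max_len: int) -> list[int]:
    """Wrap token IDs as [CLS] ... [SEP], truncating or padding to max_len."""
    seq = [CLS_ID] + token_ids
    if len(seq) >= max_len:
        seq = seq[:max_len - 1]  # leave room for the final [SEP]
    seq.append(SEP_ID)
    seq.extend([PAD_ID] * (max_len - len(seq)))  # pad up to max_len
    return seq


# Two "sentences" already converted to token IDs (made-up values)
batch = [[10, 11, 12], [20, 21, 22, 23, 24, 25, 26, 27]]
input_ids = torch.tensor([build_sequence(ids, max_len=6) for ids in batch])

# Same comparison BertModel uses to build its attention padding mask
pad_mask = input_ids == PAD_ID
print(input_ids)
print(pad_mask)
```

The first sentence is short, so it is padded; the second is too long, so it is cut to leave room for the closing [SEP]. Every row ends up with the same length, which is what allows the batch to be stacked into one 2D tensor.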

You set up two DataLoader objects: one for the training set and one for the validation set. The training set is shuffled, but the validation set is not.

Next, you need to set up a model for GLUE tasks. Since this is a sentence classification task, you just need to add a linear layer on top of the BERT model to project the hidden state of the [CLS] token to the number of labels. Below is the implementation:

...

class BertForSequenceClassification(nn.Module):

    """BERT model for GLUE tasks."""

    def __init__(self, config: BertConfig, num_labels: int):

        super().__init__()

        self.bert = BertModel(config)

        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids: torch.Tensor, pad_id: int = 0) -> torch.Tensor:

        # pooled_output corresponds to the [CLS] token

        token_type_ids = torch.zeros_like(input_ids)

        seq_output, pooled_output = self.bert(input_ids, token_type_ids, pad_id=pad_id)

        logits = self.classifier(pooled_output)

        return logits

You reuse the BertModel and BertConfig classes defined in the previous post. In the BertForSequenceClassification class, you use the foundation BERT model to process the input sequence. The pooled output, which corresponds to the [CLS] token, is then passed to a linear layer to project it to the number of labels. The sequence output, however, is unused. The model’s output is the logits for the classification task. In the case of sentiment classification, this is a vector of two values for each sample.
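The shape flow of the classification head can be checked in isolation. The sketch below replaces the BERT backbone with a random tensor standing in for the pooled [CLS] output (an assumption made only to keep the example self-contained) and applies the same kind of `nn.Linear` head.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, hidden_size, num_labels = 4, 768, 2

# Stand-in for the pooled [CLS] output of BertModel: (batch_size, hidden_size).
# BertPooler ends with tanh, so the values are squashed into (-1, 1) here too.
pooled_output = torch.tanh(torch.randn(batch_size, hidden_size))

classifier = nn.Linear(hidden_size, num_labels)
logits = classifier(pooled_output)       # (batch_size, num_labels)
probs = torch.softmax(logits, dim=-1)    # class probabilities, rows sum to 1
predictions = logits.argmax(dim=-1)      # predicted label per sample: 0 or 1
print(logits.shape, predictions)
```

The model itself returns raw logits rather than probabilities because `nn.CrossEntropyLoss`, used in the training loop, expects logits and applies the softmax internally.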

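Fine-tuning BERT for other tasks follows a similar architecture: the foundation model stays the same, and a small task-specific head is added on top. The figure below, from the BERT paper, shows several of these architectures; the model above corresponds to (b), single-sentence classification: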

Different fine-tuning architecture of BERT. Figure from the BERT paper.

Since you have already trained the foundation BERT model, you can instantiate the model for sequence classification and then load the pretrained weights for the foundation model:

...

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = BertConfig()

num_labels = 2  # SST-2 labels are 0 (negative) or 1 (positive)

model = BertForSequenceClassification(config, num_labels)

model.to(device)

model.bert.load_state_dict(torch.load("bert_model.pth", map_location=device))
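Loading the checkpoint into `model.bert` rather than into `model` works because the saved state dict belongs to the backbone alone. The sketch below reproduces the mechanism with tiny hypothetical modules (`TinyBackbone` and `TinyClassifier` stand in for `BertModel` and `BertForSequenceClassification`) and an in-memory buffer instead of `bert_model.pth`.

```python
import io

import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for BertModel, kept tiny for illustration."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)


class TinyClassifier(nn.Module):
    """Backbone plus a new task head, mirroring BertForSequenceClassification."""
    def __init__(self, num_labels: int):
        super().__init__()
        self.bert = TinyBackbone()
        self.classifier = nn.Linear(8, num_labels)


# "Pre-train" a backbone and save only its weights (like bert_model.pth)
pretrained = TinyBackbone()
buffer = io.BytesIO()
torch.save(pretrained.state_dict(), buffer)
buffer.seek(0)

# Build the task model, then load the weights into the .bert submodule only;
# the classifier head keeps its fresh random initialization
model = TinyClassifier(num_labels=2)
model.bert.load_state_dict(torch.load(buffer, map_location="cpu"))
print(list(model.state_dict().keys()))
```

The checkpoint's keys (`encoder.weight`, `encoder.bias`) have no `bert.` prefix, so they match the submodule but not the full model; calling `model.load_state_dict()` with them would raise an error under the default `strict=True`.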

Now you can run the training loop. Compared to pre-training, fine-tuning only requires a few epochs. Otherwise, the training loop is quite typical:


...

loss_fn = nn.CrossEntropyLoss()

optimizer = optim.AdamW(model.parameters(), lr=2e-5)

num_epochs = 3

for epoch in range(num_epochs):

    model.train()

    # Training

    with tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:

        for batch in pbar:

            # get batched data

            input_ids, labels = batch

            input_ids = input_ids.to(device)

            labels = labels.to(device)

            # forward pass

            logits = model(input_ids)

            # backward pass

            optimizer.zero_grad()

            loss = loss_fn(logits, labels)

            loss.backward()

            optimizer.step()

            # update progress bar

            pbar.set_postfix(loss=float(loss))

    # Validation: Keep track of the average loss and accuracy

    model.eval()

    val_loss, num_matches, num_batches, num_samples = 0, 0, 0, 0

    with torch.no_grad():

        for batch in val_loader:

            # get batched data

            input_ids, labels = batch

            input_ids = input_ids.to(device)

            labels = labels.to(device)

            # forward pass on validation data

            logits = model(input_ids)

            # compute loss

            loss = loss_fn(logits, labels)

            val_loss += loss.item()

            num_batches += 1

            # compute accuracy

            predictions = logits.argmax(dim=-1)

            num_matches += (predictions == labels).sum().item()

            num_samples += len(labels)

    avg_loss = val_loss / num_batches

    acc = num_matches / num_samples

    print(f"Validation {epoch+1}/{num_epochs}: acc {acc:.4f}, avg loss {avg_loss:.4f}")

Running this code, you may see:

Epoch 1/3: 100%|██████████████████████████| 4210/4210 [02:14<00:00, 31.37it/s, loss=0.844]

Validation 1/3: acc 0.5092, avg loss 0.7097

Epoch 2/3: 100%|██████████████████████████| 4210/4210 [02:13<00:00, 31.46it/s, loss=0.591]

Validation 2/3: acc 0.5092, avg loss 0.7164

Epoch 3/3: 100%|██████████████████████████| 4210/4210 [02:13<00:00, 31.51it/s, loss=0.699]

Validation 3/3: acc 0.5092, avg loss 0.6932

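Since you have both training and validation sets, you train on the training set and then evaluate on the validation set. Be sure to call model.train() and model.eval() to switch the model between training and evaluation modes, since the model uses dropout layers. Note also that the validation accuracy stays at 0.5092 in all three epochs even though the training loss changes. This matches the share of the majority class in the validation set, which suggests the model predicts the same label for every sentence and has not yet learned to separate the two classes.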
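The reason for switching modes is easy to verify on a small model. The sketch below uses a hypothetical two-layer `nn.Sequential` in place of the BERT classifier: in training mode the dropout layer randomly zeros activations, so two forward passes on the same input differ, while in evaluation mode the outputs are identical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A tiny model with a dropout layer, standing in for the BERT classifier
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.1))
x = torch.randn(8, 16)

model.train()                 # dropout active: random masks on every call
out_a, out_b = model(x), model(x)

model.eval()                  # dropout disabled: deterministic outputs
with torch.no_grad():
    out_c, out_d = model(x), model(x)

print(torch.equal(out_a, out_b), torch.equal(out_c, out_d))
```

`model.train()` and `model.eval()` set the mode recursively on every submodule, so a single call on the top-level model also covers the dropout layers inside `BertModel`.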

That’s all you need to do to fine-tune a BERT model for GLUE tasks. Below is the complete code for sequence classification:


import dataclasses

import functools

import torch

import torch.nn as nn

import torch.optim as optim

import tqdm

from datasets import load_dataset

from tokenizers import Tokenizer

from torch import Tensor

# BERT config and model defined previously

@dataclasses.dataclass

class BertConfig:

    """Configuration for BERT model."""

    vocab_size: int = 30522

    num_layers: int = 12

    hidden_size: int = 768

    num_heads: int = 12

    dropout_prob: float = 0.1

    pad_id: int = 0

    max_seq_len: int = 512

    num_types: int = 2

class BertBlock(nn.Module):

    """One transformer block in BERT."""

    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):

        super().__init__()

        self.attention = nn.MultiheadAttention(hidden_size, num_heads,

                                               dropout=dropout_prob, batch_first=True)

        self.attn_norm = nn.LayerNorm(hidden_size)

        self.ff_norm = nn.LayerNorm(hidden_size)

        self.dropout = nn.Dropout(dropout_prob)

        self.feed_forward = nn.Sequential(

            nn.Linear(hidden_size, 4 * hidden_size),

            nn.GELU(),

            nn.Linear(4 * hidden_size, hidden_size),

        )

    def forward(self, x: Tensor, pad_mask: Tensor) -> Tensor:

        # self-attention with padding mask and post-norm

        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)

        x = self.attn_norm(x + attn_output)

        # feed-forward with GeLU activation and post-norm

        ff_output = self.feed_forward(x)

        x = self.ff_norm(x + self.dropout(ff_output))

        return x

class BertPooler(nn.Module):

    """Pooler layer for BERT to process the [CLS] token output."""

    def __init__(self, hidden_size: int):

        super().__init__()

        self.dense = nn.Linear(hidden_size, hidden_size)

        self.activation = nn.Tanh()

    def forward(self, x: Tensor) -> Tensor:

        x = self.dense(x)

        x = self.activation(x)

        return x

class BertModel(nn.Module):

    """Backbone of BERT model."""

    def __init__(self, config: BertConfig):

        super().__init__()

        # embedding layers

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,

                                            padding_idx=config.pad_id)

        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)

        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)

        self.embeddings_norm = nn.LayerNorm(config.hidden_size)

        self.embeddings_dropout = nn.Dropout(config.dropout_prob)

        # transformer blocks

        self.blocks = nn.ModuleList([

            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)

            for _ in range(config.num_layers)

        ])

        # [CLS] pooler layer

        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: Tensor, token_type_ids: Tensor, pad_id: int = 0,

                ) -> tuple[Tensor, Tensor]:

        # create attention mask for padding tokens

        pad_mask = input_ids == pad_id

        # convert integer tokens to embedding vectors

        batch_size, seq_len = input_ids.shape

        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.position_embeddings(position_ids)

        type_embeddings = self.type_embeddings(token_type_ids)

        token_embeddings = self.word_embeddings(input_ids)

        x = token_embeddings + type_embeddings + position_embeddings

        x = self.embeddings_norm(x)

        x = self.embeddings_dropout(x)

        # process the sequence with transformer blocks

        for block in self.blocks:

            x = block(x, pad_mask)

        # pool the hidden state of the `[CLS]` token

        pooled_output = self.pooler(x[:, 0, :])

        return x, pooled_output

# Define new BERT model for sequence classification

class BertForSequenceClassification(nn.Module):

    """BERT model for GLUE tasks."""

    def __init__(self, config: BertConfig, num_labels: int):

        super().__init__()

        self.bert = BertModel(config)

        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids: Tensor, pad_id: int = 0) -> Tensor:

        # pooled_output corresponds to the [CLS] token

        token_type_ids = torch.zeros_like(input_ids)

        seq_output, pooled_output = self.bert(input_ids, token_type_ids, pad_id=pad_id)

        logits = self.classifier(pooled_output)

        return logits

# Load GLUE dataset (e.g., 'sst2' for sentiment classification)

task = "sst2"

dataset = load_dataset("glue", task)

num_labels = 2  # dataset["train"]["label"] is either 0 or 1

# Load the pretrained BERT tokenizer

TOKENIZER_PATH = "wikitext-2_wordpiece.json"

tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

# Setup dataloader for training and validation datasets

def collate(batch: list[dict], tokenizer: Tokenizer, max_len: int) -> tuple[Tensor, Tensor]:

    """Collate variable-length sequences in the dataset."""

    cls_id = tokenizer.token_to_id("[CLS]")

    sep_id = tokenizer.token_to_id("[SEP]")

    pad_id = tokenizer.token_to_id("[PAD]")

    sentences: list[str] = [item["sentence"] for item in batch]

    labels = torch.tensor([item["label"] for item in batch])

    input_ids = []

    for sentence in sentences:

        seq = [cls_id]

        seq.extend(tokenizer.encode(sentence).ids)

        if len(seq) >= max_len:

            seq = seq[:max_len-1]

        seq.append(sep_id)

        num_pad = max_len - len(seq)

        seq.extend([pad_id] * num_pad)

        input_ids.append(seq)

    input_ids = torch.tensor(input_ids, dtype=torch.long)

    return input_ids, labels

batch_size = 16

max_len = 128

collate_fn = functools.partial(collate, tokenizer=tokenizer, max_len=max_len)

train_loader = torch.utils.data.DataLoader(dataset["train"], batch_size=batch_size,

                                           shuffle=True, collate_fn=collate_fn)

val_loader = torch.utils.data.DataLoader(dataset["validation"], batch_size=batch_size,

                                         shuffle=False, collate_fn=collate_fn)

# Create classification model with a pretrained foundation BERT model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = BertConfig()

model = BertForSequenceClassification(config, num_labels)

model.to(device)

model.bert.load_state_dict(torch.load("bert_model.pth", map_location=device))

# Training setup

loss_fn = nn.CrossEntropyLoss()

optimizer = optim.AdamW(model.parameters(), lr=2e-5)

num_epochs = 3

for epoch in range(num_epochs):

    model.train()

    # Training

    with tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:

        for batch in pbar:

            # get batched data

            input_ids, labels = batch

            input_ids = input_ids.to(device)

            labels = labels.to(device)

            # forward pass

            logits = model(input_ids)

            # backward pass

            optimizer.zero_grad()

            loss = loss_fn(logits, labels)

            loss.backward()

            optimizer.step()

            # update progress bar

            pbar.set_postfix(loss=float(loss))

    # Validation: Keep track of the average loss and accuracy

    model.eval()

    val_loss, num_matches, num_batches, num_samples = 0, 0, 0, 0

    with torch.no_grad():

        for batch in val_loader:

            # get batched data

            input_ids, labels = batch

            input_ids = input_ids.to(device)

            labels = labels.to(device)

            # forward pass on validation data

            logits = model(input_ids)

            # compute loss

            loss = loss_fn(logits, labels)

            val_loss += loss.item()

            num_batches += 1

            # compute accuracy

            predictions = logits.argmax(dim=-1)

            num_matches += (predictions == labels).sum().item()

            num_samples += len(labels)

    avg_loss = val_loss / num_batches

    acc = num_matches / num_samples

    print(f"Validation {epoch+1}/{num_epochs}: acc {acc:.4f}, avg loss {avg_loss:.4f}")

# Save the fine-tuned model

torch.save(model.state_dict(), "bert_model_glue_sst2.pth")

Fine-tuning a BERT Model for SQuAD

SQuAD is a question answering dataset. Each sample contains a question and a context paragraph, and the answer to the question is a span of text within the context. This makes it an extractive question answering task rather than a general one, since the answer is always a substring of the context. The "squad" dataset loaded below is SQuAD v1.1, in which every question has an answer; SQuAD 2.0 adds questions that cannot be answered from the context.

Let’s take a look at one sample in the dataset:

from datasets import load_dataset

dataset = load_dataset("squad")

print("Train size:", len(dataset["train"]))

print("Validation size:", len(dataset["validation"]))

# print one sample

print(dataset["train"][42])

Running this code, the output is:

Train size: 87599

Validation size: 10570

{'id': '5733ae924776f41900661016', 'title': 'University_of_Notre_Dame',

'context': 'Notre Dame is known for its competitive admissions, ...',

'question': 'What percentage of students at Notre Dame participated in the Early Action program?', 'answers': {'text': ['39.1%'], 'answer_start': [488]}}

The SQuAD dataset has only training and validation splits. Each sample is a dictionary with the keys "id", "title", "context", "question", and "answers". The "answers" key is a dictionary containing the answer text and its offset in the context.
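The `answer_start` value is a character offset into the context string, not a token index. The sketch below uses a made-up sample in the same format (not one from the dataset) to show how the offset and the answer text recover the answer from the context; the helper `answer_char_span()` is hypothetical.

```python
from typing import Optional

# A made-up sample in the SQuAD format, for illustration only
sample = {
    "id": "example-0",
    "context": "The library opened in 1998 and holds two million books.",
    "question": "When did the library open?",
    "answers": {"text": ["1998"], "answer_start": [22]},
}


def answer_char_span(sample: dict) -> Optional[tuple[int, int]]:
    """Return (start, end) character offsets of the first answer, or None."""
    answers = sample["answers"]
    if not answers["text"]:
        return None  # possible in SQuAD 2.0, never in SQuAD v1.1
    start = answers["answer_start"][0]
    end = start + len(answers["text"][0])  # exclusive end, like Python slicing
    return start, end


span = answer_char_span(sample)
print(span, repr(sample["context"][span[0]:span[1]]))
```

The collate function below does not use this offset; it searches for the answer's token IDs instead. If you need the exact occurrence at `answer_start`, the `offsets` attribute of a `tokenizers` `Encoding` gives the character span of each token and could be used to map character offsets to token positions.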

To train a model, you need to batch and process the data samples into tensors as you did for the GLUE tasks. Let’s create a custom collate function for the SQuAD dataset:


...

def collate(batch: list[dict], tokenizer: tokenizers.Tokenizer, max_len: int):

    cls_id = tokenizer.token_to_id("[CLS]")

    sep_id = tokenizer.token_to_id("[SEP]")

    pad_id = tokenizer.token_to_id("[PAD]")

    input_ids_list = []

    token_type_ids_list = []

    start_positions = []

    end_positions = []

    for item in batch:

        # Tokenize question and context

        question, context = item["question"], item["context"]

        question_ids = tokenizer.encode(question).ids

        context_ids = tokenizer.encode(context).ids

        # Build input: [CLS] question [SEP] context [SEP]

        input_ids = [cls_id, *question_ids, sep_id, *context_ids, sep_id]

        token_type_ids = [0] * (len(question_ids)+2) + [1] * (len(context_ids)+1)

        # Truncate or pad to max length

        if len(input_ids) > max_len:

            input_ids = input_ids[:max_len]

            token_type_ids = token_type_ids[:max_len]

        else:

            input_ids.extend([pad_id] * (max_len - len(input_ids)))

            token_type_ids.extend([1] * (max_len - len(token_type_ids)))

        # Find answer position in tokens: Answer may not be in the context

        start_pos = end_pos = 0

        if len(item["answers"]["text"]) > 0:

            answers = tokenizer.encode(item["answers"]["text"][0]).ids

            # find the context offset of the answer in context_ids

            for i in range(len(context_ids) - len(answers) + 1):

                if context_ids[i:i+len(answers)] == answers:

                    start_pos = i + len(question_ids) + 2

                    end_pos = start_pos + len(answers) - 1

                    break

            if end_pos >= max_len:

                start_pos = end_pos = 0  # answer is clipped, hence no answer

        input_ids_list.append(input_ids)

        token_type_ids_list.append(token_type_ids)

        start_positions.append(start_pos)

        end_positions.append(end_pos)

    input_ids_list = torch.tensor(input_ids_list)

    token_type_ids_list = torch.tensor(token_type_ids_list)

    start_positions = torch.tensor(start_positions)

    end_positions = torch.tensor(end_positions)

    return (input_ids_list, token_type_ids_list, start_positions, end_positions)

batch_size = 16

max_len = 384  # Longer for Q&A to accommodate context

collate_fn = functools.partial(collate, tokenizer=tokenizer, max_len=max_len)

train_loader = torch.utils.data.DataLoader(dataset["train"], batch_size=batch_size,

                                           shuffle=True, collate_fn=collate_fn)

val_loader = torch.utils.data.DataLoader(dataset["validation"], batch_size=batch_size,

                                         shuffle=False, collate_fn=collate_fn)

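This collate function is more complex than the one for GLUE tasks because the input is a pair of text segments in the format [CLS] question [SEP] context [SEP]. The context may be clipped to fit the maximum length. The question, context, and answer are all converted into token sequences, and the inner for-loop searches the context tokens for the first occurrence of the answer tokens. If the answer is not found, or it lies beyond the truncated length, the sample is marked as having no answer by setting both the start and end positions to 0, the position of the [CLS] token. Because the answer is tokenized on its own, its tokens may occasionally differ from the tokens of the same text inside the context, in which case the search fails.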

One of the most important roles of the collate function is to create tensors for the whole batch. Here you produce four tensors: the input (including both question and context), the token type IDs (denoting which tokens belong to the question and which belong to the context), and the start and end positions of the answer span.
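The core of the inner loop, locating the answer's tokens inside the context and shifting the match by the length of the [CLS] question [SEP] prefix, can be tried on its own. This is a sketch with made-up token IDs; `find_answer_span()` and the special-token values are hypothetical and not part of the `tokenizers` library.

```python
CLS_ID, SEP_ID = 1, 2  # hypothetical special-token IDs


def find_answer_span(context_ids: list[int], answer_ids: list[int],
                     offset: int) -> tuple[int, int]:
    """Find the first occurrence of answer_ids in context_ids.

    Returns inclusive (start, end) positions in the full input, shifted by
    offset (the number of tokens before the context), or (0, 0), which
    points at [CLS], when there is no match.
    """
    n = len(answer_ids)
    if n == 0:
        return 0, 0
    for i in range(len(context_ids) - n + 1):
        if context_ids[i:i + n] == answer_ids:
            return i + offset, i + offset + n - 1
    return 0, 0


question_ids = [30, 31]
context_ids = [40, 41, 42, 43]
answer_ids = [42, 43]

# [CLS] question [SEP] context [SEP]; segment 0 for [CLS] question [SEP]
input_ids = [CLS_ID, *question_ids, SEP_ID, *context_ids, SEP_ID]
token_type_ids = [0] * (len(question_ids) + 2) + [1] * (len(context_ids) + 1)

start, end = find_answer_span(context_ids, answer_ids, offset=len(question_ids) + 2)
print(start, end, input_ids[start:end + 1])
```

The offset `len(question_ids) + 2` accounts for the [CLS] token, the question, and the first [SEP], which is the same shift the article's collate function applies.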

Next, you need to set up a model for SQuAD tasks. The strategy is to project the sequence output of the BERT model so that each token gets two scores (logits): one for being the start and one for being the end of the answer span. You then pick the tokens with the highest start and end scores to form the answer span. The implementation is straightforward:


...

class BertForQuestionAnswering(nn.Module):

    """BERT model for SQuAD question answering."""

    def __init__(self, config):

        super().__init__()

        self.bert = BertModel(config)

        # Two outputs: start and end position logits

        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids, token_type_ids, pad_id: int = 0):

        # Get sequence output from BERT (batch_size, seq_len, hidden_size)

        seq_output, pooled_output = self.bert(input_ids, token_type_ids, pad_id=pad_id)

        # Project to start and end logits

        logits = self.qa_outputs(seq_output)  # (batch_size, seq_len, 2)

        start_logits = logits[:, :, 0]  # (batch_size, seq_len)

        end_logits = logits[:, :, 1]    # (batch_size, seq_len)

        return start_logits, end_logits

The foundation BERT model produces both a sequence output and a pooled output; for SQuAD tasks you use only the sequence output. A linear layer projects each token's hidden state to two values, which are split into the start and end position logits and returned separately. To convert them to probabilities, apply the softmax function along the sequence dimension; this can be done outside the model. This fine-tuning model follows architecture (c) in the figure from the BERT paper shown in the previous section.
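A short sketch of that post-processing step, with random tensors in place of real model outputs (an assumption to keep it self-contained): the softmax is taken over the last dimension, so each sample gets one distribution over token positions for the start and another for the end.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len = 2, 10

# Random stand-ins for the start/end logits returned by BertForQuestionAnswering
start_logits = torch.randn(batch_size, seq_len)
end_logits = torch.randn(batch_size, seq_len)

# Softmax over the sequence dimension: a distribution over token positions
start_probs = torch.softmax(start_logits, dim=-1)
end_probs = torch.softmax(end_logits, dim=-1)

# Softmax is monotonic, so the most likely position equals the logit argmax
pred_start = start_probs.argmax(dim=-1)
pred_end = end_probs.argmax(dim=-1)
print(pred_start, pred_end)
```

Because the argmax is unchanged by the softmax, the training loop can work directly with logits; probabilities are mainly useful when you want a confidence score for the predicted span.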

As in the example for GLUE tasks above, you can instantiate the model and load the pretrained weights:

...

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = BertConfig()

model = BertForQuestionAnswering(config)

model.to(device)

model.bert.load_state_dict(torch.load("bert_model.pth", map_location=device))

And finally, you can run the training loop for fine-tuning:


...

loss_fn = nn.CrossEntropyLoss()

optimizer = optim.AdamW(model.parameters(), lr=2e-5)

num_epochs = 3

for epoch in range(num_epochs):

    model.train()

    # Training

    with tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:

        for batch in pbar:

            # get batched data

            input_ids, token_type_ids, start_positions, end_positions = batch

            input_ids = input_ids.to(device)

            token_type_ids = token_type_ids.to(device)

            start_positions = start_positions.to(device)

            end_positions = end_positions.to(device)

            # forward pass

            start_logits, end_logits = model(input_ids, token_type_ids)

            # backward pass

            optimizer.zero_grad()

            start_loss = loss_fn(start_logits, start_positions)

            end_loss = loss_fn(end_logits, end_positions)

            loss = start_loss + end_loss

            loss.backward()

            optimizer.step()

            # update progress bar

            pbar.set_postfix(loss=float(loss))

    # Validation: Keep track of the average loss and accuracy

    model.eval()

    val_loss, num_matches, num_batches, num_samples = 0, 0, 0, 0

    with torch.no_grad():

        for batch in val_loader:

            # get batched data

            input_ids, token_type_ids, start_positions, end_positions = batch

            input_ids = input_ids.to(device)

            token_type_ids = token_type_ids.to(device)

            start_positions = start_positions.to(device)

            end_positions = end_positions.to(device)

            # forward pass on validation data

            start_logits, end_logits = model(input_ids, token_type_ids)

            # compute loss

            start_loss = loss_fn(start_logits, start_positions)

            end_loss = loss_fn(end_logits, end_positions)

            loss = start_loss + end_loss

            val_loss += loss.item()

            num_batches += 1

            # compute accuracy

            pred_start = start_logits.argmax(dim=-1)

            pred_end = end_logits.argmax(dim=-1)

            match = (pred_start == start_positions) & (pred_end == end_positions)

            num_matches += match.sum().item()

            num_samples += len(start_positions)

    avg_loss = val_loss / num_batches

    acc = num_matches / num_samples

    print(f"Validation {epoch+1}/{num_epochs}: acc {acc:.4f}, avg loss {avg_loss:.4f}")

The training loop is similar to the one for GLUE tasks. Instead of using the pooled output, you now use the output corresponding to each token in the sequence. The token with the highest logit value is the predicted start or end position. The loss function is the sum of the cross-entropy losses for the predicted start and end positions.
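Treating each token position as a class is what makes `nn.CrossEntropyLoss` applicable here: the start logits of shape (batch_size, seq_len) act as class scores, and the start position is the class label. The toy numbers below are made up; the sketch checks that the summed loss equals the negative log-probability of the true start position plus that of the true end position.

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# Made-up logits for a batch of one sample and a sequence of 5 tokens
start_logits = torch.tensor([[0.1, 0.2, 3.0, 0.1, 0.0]])  # (batch_size, seq_len)
end_logits = torch.tensor([[0.0, 0.1, 0.2, 2.5, 0.3]])
start_positions = torch.tensor([2])  # true start token index per sample
end_positions = torch.tensor([3])    # true end token index per sample

# Each token position acts as a class: seq_len-way classification
loss = loss_fn(start_logits, start_positions) + loss_fn(end_logits, end_positions)

# The same value by hand: negative log-probability of the true positions
manual = -(torch.log_softmax(start_logits, dim=-1)[0, 2]
           + torch.log_softmax(end_logits, dim=-1)[0, 3])
print(float(loss), float(manual))
```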

This is a simplified way to use the model’s output. You can refine the logic by finding the start-end pair with the highest combined score under the constraint that the end position is greater than or equal to the start position. This may improve the model’s performance.

Running this code, you may see:

Epoch 1/3: 100%|████████████████████████████| 5475/5475 [07:45<00:00, 11.77it/s, loss=9.7]

Validation 1/3: acc 0.0189, avg loss 8.6972

Epoch 2/3: 100%|███████████████████████████| 5475/5475 [07:44<00:00, 11.78it/s, loss=8.37]

Validation 2/3: acc 0.0358, avg loss 8.2596

Epoch 3/3: 100%|███████████████████████████| 5475/5475 [07:44<00:00, 11.78it/s, loss=7.85]

Validation 3/3: acc 0.0449, avg loss 7.9882

You may notice that the model's performance is poor. This is likely because the foundation BERT model was pre-trained only on the small WikiText-2 dataset, so its representations do not generalize well to more complex tasks. For better performance in practical applications, you should start from the official pretrained BERT weights instead.
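One cheap improvement, mentioned earlier, is to choose the start-end pair with the highest combined score such that the end does not come before the start. Here is a minimal sketch for a single sample, assuming 1D logit tensors like one row of `start_logits` and `end_logits` from the model. The function `best_span()` and its `max_answer_len` cap (a common heuristic to rule out implausibly long spans) are illustrative assumptions, not part of the article's code.

```python
import torch


def best_span(start_logits: torch.Tensor, end_logits: torch.Tensor,
              max_answer_len: int = 30) -> tuple[int, int]:
    """Return (start, end) maximizing start_logits[start] + end_logits[end]
    subject to start <= end < start + max_answer_len, for a single sample.
    """
    seq_len = start_logits.size(0)
    # scores[i, j] = start_logits[i] + end_logits[j]
    scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)
    # diff[i, j] = j - i; keep only spans with 0 <= j - i < max_answer_len
    idx = torch.arange(seq_len, device=start_logits.device)
    diff = idx.unsqueeze(0) - idx.unsqueeze(1)
    valid = (diff >= 0) & (diff < max_answer_len)
    scores = scores.masked_fill(~valid, float("-inf"))
    # argmax over the flattened matrix, converted back to (row, column)
    start, end = divmod(int(scores.argmax()), seq_len)
    return start, end


# Toy logits where the independent argmax values would give end < start
start_logits = torch.tensor([0.0, 0.0, 5.0, 0.0])
end_logits = torch.tensor([0.0, 4.0, 0.0, 1.0])
print(int(start_logits.argmax()), int(end_logits.argmax()))  # independent picks
print(best_span(start_logits, end_logits))                    # constrained pick
```

In the toy example, taking the two argmax values independently gives start 2 and end 1, which is not a valid span, while the constrained search returns (2, 3). Building the full seq_len x seq_len score matrix is affordable at max_len = 384; in the validation loop you would call it once per sample.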

Below is the complete code for fine-tuning a BERT model for SQuAD tasks:


import collections

import dataclasses

import functools

import torch

import torch.nn as nn

import torch.optim as optim

import tqdm

from datasets import load_dataset

from tokenizers import Tokenizer

from torch import Tensor

# BERT config and model defined previously

@dataclasses.dataclass

class BertConfig:

    """Configuration for BERT model."""

    vocab_size: int = 30522

    num_layers: int = 12

    hidden_size: int = 768

    num_heads: int = 12

    dropout_prob: float = 0.1

    pad_id: int = 0

    max_seq_len: int = 512

    num_types: int = 2

class BertBlock(nn.Module):

    """One transformer block in BERT."""

    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):

        super().__init__()

        self.attention = nn.MultiheadAttention(hidden_size, num_heads,

                                               dropout=dropout_prob, batch_first=True)

        self.attn_norm = nn.LayerNorm(hidden_size)

        self.ff_norm = nn.LayerNorm(hidden_size)

        self.dropout = nn.Dropout(dropout_prob)

        self.feed_forward = nn.Sequential(

            nn.Linear(hidden_size, 4 * hidden_size),

            nn.GELU(),

            nn.Linear(4 * hidden_size, hidden_size),

        )

    def forward(self, x: Tensor, pad_mask: Tensor) -> Tensor:

        # self-attention with padding mask and post-norm

        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)

        x = self.attn_norm(x + attn_output)

        # feed-forward with GeLU activation and post-norm

        ff_output = self.feed_forward(x)

        x = self.ff_norm(x + self.dropout(ff_output))

        return x

class BertPooler(nn.Module):

    """Pooler layer for BERT to process the [CLS] token output."""

    def __init__(self, hidden_size: int):

        super().__init__()

        self.dense = nn.Linear(hidden_size, hidden_size)

        self.activation = nn.Tanh()

    def forward(self, x: Tensor) -> Tensor:

        x = self.dense(x)

        x = self.activation(x)

        return x

class BertModel(nn.Module):

    """Backbone of BERT model."""

    def __init__(self, config: BertConfig):

        super().__init__()

        # embedding layers

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,

                                            padding_idx=config.pad_id)

        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)

        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)

        self.embeddings_norm = nn.LayerNorm(config.hidden_size)

        self.embeddings_dropout = nn.Dropout(config.dropout_prob)

        # transformer blocks

        self.blocks = nn.ModuleList([

            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)

            for _ in range(config.num_layers)

        ])

        # [CLS] pooler layer

        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: Tensor, token_type_ids: Tensor, pad_id: int = 0,

                ) -> tuple[Tensor, Tensor]:

        # create attention mask for padding tokens

        pad_mask = input_ids == pad_id

        # convert integer tokens to embedding vectors

        batch_size, seq_len = input_ids.shape

        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.position_embeddings(position_ids)

        type_embeddings = self.type_embeddings(token_type_ids)

        token_embeddings = self.word_embeddings(input_ids)

        x = token_embeddings + type_embeddings + position_embeddings

        x = self.embeddings_norm(x)

        x = self.embeddings_dropout(x)

        # process the sequence with transformer blocks

        for block in self.blocks:

            x = block(x, pad_mask)

        # pool the hidden state of the `[CLS]` token

        pooled_output = self.pooler(x[:, 0, :])

        return x, pooled_output

# Define new BERT model for question answering

class BertForQuestionAnswering(nn.Module):

    """BERT model for SQuAD question answering."""

    def __init__(self, config: BertConfig):

        super().__init__()

        self.bert = BertModel(config)

        # Two outputs: start and end position logits

        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self,

        input_ids: Tensor,

        token_type_ids: Tensor,

        pad_id: int = 0,

    ) -> tuple[Tensor, Tensor]:

        # Get sequence output from BERT (batch_size, seq_len, hidden_size)

        seq_output, pooled_output = self.bert(input_ids, token_type_ids, pad_id=pad_id)

        # Project to start and end logits

        logits = self.qa_outputs(seq_output)  # (batch_size, seq_len, 2)

        start_logits = logits[:, :, 0]  # (batch_size, seq_len)

        end_logits = logits[:, :, 1]    # (batch_size, seq_len)

        return start_logits, end_logits

# Load SQuAD dataset for question answering

dataset = load_dataset("squad")

# Load the pretrained BERT tokenizer

TOKENIZER_PATH = "wikitext-2_wordpiece.json"

tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

# Setup collate function to tokenize question-context pairs for the model

def collate(batch: list[dict], tokenizer: Tokenizer, max_len: int,

            ) -> tuple[Tensor, Tensor, Tensor, Tensor]:

    """Collate question-context pairs for the model."""

    cls_id = tokenizer.token_to_id("[CLS]")

    sep_id = tokenizer.token_to_id("[SEP]")

    pad_id = tokenizer.token_to_id("[PAD]")

    input_ids_list = []

    token_type_ids_list = []

    start_positions = []

    end_positions = []

    for item in batch:

        # Tokenize question and context

        question, context = item["question"], item["context"]

        question_ids = tokenizer.encode(question).ids

        context_ids = tokenizer.encode(context).ids

        # Build input: [CLS] question [SEP] context [SEP]

        input_ids = [cls_id, *question_ids, sep_id, *context_ids, sep_id]

        token_type_ids = [0] * (len(question_ids)+2) + [1] * (len(context_ids)+1)

        # Truncate or pad to max length

        if len(input_ids) > max_len:

            input_ids = input_ids[:max_len]

            token_type_ids = token_type_ids[:max_len]

        else:

            input_ids.extend([pad_id] * (max_len - len(input_ids)))

            token_type_ids.extend([1] * (max_len - len(token_type_ids)))

        # Find the answer's token position; keep 0 (the [CLS] token) if it is not found

        start_pos = end_pos = 0

        if len(item["answers"]["text"]) > 0:

            answers = tokenizer.encode(item["answers"]["text"][0]).ids

            # find the context offset of the answer in context_ids

            for i in range(len(context_ids) - len(answers) + 1):

                if context_ids[i:i+len(answers)] == answers:

                    start_pos = i + len(question_ids) + 2

                    end_pos = start_pos + len(answers) - 1

                    break

            if end_pos >= max_len:

                start_pos = end_pos = 0  # answer is clipped, hence no answer

        input_ids_list.append(input_ids)

        token_type_ids_list.append(token_type_ids)

        start_positions.append(start_pos)

        end_positions.append(end_pos)

    input_ids_list = torch.tensor(input_ids_list)

    token_type_ids_list = torch.tensor(token_type_ids_list)

    start_positions = torch.tensor(start_positions)

    end_positions = torch.tensor(end_positions)

    return (input_ids_list, token_type_ids_list, start_positions, end_positions)

batch_size = 16

max_len = 384  # Longer for Q&A to accommodate context

collate_fn = functools.partial(collate, tokenizer=tokenizer, max_len=max_len)

train_loader = torch.utils.data.DataLoader(dataset["train"], batch_size=batch_size,

                                           shuffle=True, collate_fn=collate_fn)

val_loader = torch.utils.data.DataLoader(dataset["validation"], batch_size=batch_size,

                                         shuffle=False, collate_fn=collate_fn)

# Create Q&A model with a pretrained foundation BERT model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = BertConfig()

model = BertForQuestionAnswering(config)

model.to(device)

model.bert.load_state_dict(torch.load("bert_model.pth", map_location=device))

# Training setup

loss_fn = nn.CrossEntropyLoss()

optimizer = optim.AdamW(model.parameters(), lr=2e-5)

num_epochs = 3

for epoch in range(num_epochs):

    model.train()

    # Training

    with tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:

        for batch in pbar:

            # get batched data

            input_ids, token_type_ids, start_positions, end_positions = batch

            input_ids = input_ids.to(device)

            token_type_ids = token_type_ids.to(device)

            start_positions = start_positions.to(device)

            end_positions = end_positions.to(device)

            # forward pass

            start_logits, end_logits = model(input_ids, token_type_ids)

            # backward pass

            optimizer.zero_grad()

            start_loss = loss_fn(start_logits, start_positions)

            end_loss = loss_fn(end_logits, end_positions)

            loss = start_loss + end_loss

            loss.backward()

            optimizer.step()

            # update progress bar

            pbar.set_postfix(loss=float(loss))

    # Validation: Keep track of the average loss and accuracy

    model.eval()

    val_loss, num_matches, num_batches, num_samples = 0, 0, 0, 0

    with torch.no_grad():

        for batch in val_loader:

            # get batched data

            input_ids, token_type_ids, start_positions, end_positions = batch

            input_ids = input_ids.to(device)

            token_type_ids = token_type_ids.to(device)

            start_positions = start_positions.to(device)

            end_positions = end_positions.to(device)

            # forward pass on validation data

            start_logits, end_logits = model(input_ids, token_type_ids)

            # compute loss

            start_loss = loss_fn(start_logits, start_positions)

            end_loss = loss_fn(end_logits, end_positions)

            loss = start_loss + end_loss

            val_loss += loss.item()

            num_batches += 1

            # compute accuracy

            pred_start = start_logits.argmax(dim=-1)

            pred_end = end_logits.argmax(dim=-1)

            match = (pred_start == start_positions) & (pred_end == end_positions)

            num_matches += match.sum().item()

            num_samples += len(start_positions)

    avg_loss = val_loss / num_batches

    acc = num_matches / num_samples

    print(f"Validation {epoch+1}/{num_epochs}: acc {acc:.4f}, avg loss {avg_loss:.4f}")

# Save the fine-tuned model

torch.save(model.state_dict(), "bert_model_squad.pth")
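The script above locates each answer by searching for the first occurrence of the answer's token IDs in the context, and the validation loop takes the start and end `argmax` independently, which can produce an end position before the start position. Below is a minimal sketch of two common refinements; neither is part of the original code. `char_span_to_token_span` (a hypothetical helper) maps SQuAD's character-level `answer_start` to token indices using per-token character offsets, such as the `offsets` of an `Encoding` returned by the `tokenizers` library. `best_span` (also hypothetical) picks the highest-scoring span subject to start <= end and an assumed maximum answer length of 30 tokens.

```python
from typing import List, Optional, Tuple

import torch


def char_span_to_token_span(offsets: List[Tuple[int, int]], char_start: int,
                            char_end: int) -> Optional[Tuple[int, int]]:
    """Hypothetical helper: map the character span [char_start, char_end)
    to inclusive (first_token, last_token) indices.

    offsets holds one (start, end) character pair per token. Returns None
    if no token overlaps the span.
    """
    first = last = None
    for i, (s, e) in enumerate(offsets):
        if e <= s:
            continue  # empty span, e.g. a special token with offsets (0, 0)
        if s < char_end and e > char_start:  # token overlaps the answer
            if first is None:
                first = i
            last = i
    if first is None:
        return None
    return first, last


def best_span(start_logits: torch.Tensor, end_logits: torch.Tensor,
              max_answer_len: int = 30) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: choose (start, end) maximizing
    start_logits[start] + end_logits[end] subject to
    start <= end < start + max_answer_len.

    Both inputs have shape (batch_size, seq_len); returns two (batch_size,) tensors.
    """
    seq_len = start_logits.size(1)
    # scores[b, i, j] = start_logits[b, i] + end_logits[b, j]
    scores = start_logits.unsqueeze(2) + end_logits.unsqueeze(1)
    # length[i, j] = j - i, shared across the batch
    pos = torch.arange(seq_len, device=start_logits.device)
    length = pos.unsqueeze(0) - pos.unsqueeze(1)
    valid = (length >= 0) & (length < max_answer_len)
    scores = scores.masked_fill(~valid, float("-inf"))
    # argmax over the flattened (i, j) grid, then recover i and j
    flat = scores.flatten(start_dim=1).argmax(dim=1)
    start = torch.div(flat, seq_len, rounding_mode="floor")
    end = flat % seq_len
    return start, end
```

In `collate`, the offsets would come from `tokenizer.encode(context).offsets`, with `char_start = item["answers"]["answer_start"][0]` and `char_end = char_start + len(item["answers"]["text"][0])`; the returned indices would then be shifted by `len(question_ids) + 2`, just as the original code shifts its match position. In the validation loop, `best_span(start_logits, end_logits)` would stand in for the two independent `argmax` calls. Scoring every (start, end) pair costs seq_len² memory per sample, which is modest at `max_len = 384`.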

Summary

In this article, you learned how to fine-tune a BERT model for GLUE and SQuAD tasks. Specifically, you learned:

  • How to build a new model on top of BERT for fine-tuning
  • How to run the training loop for fine-tuning
  • The GLUE and SQuAD datasets and the tasks they are designed for
