7.1. Налаштування для класифікації

Reading time: 5 minutes

Що таке

Налаштування - це процес взяття попередньо навченої моделі, яка навчилася загальним мовним патернам з величезних обсягів даних, і адаптації її для виконання конкретного завдання або для розуміння специфічної мови домену. Це досягається шляхом продовження навчання моделі на меншому, специфічному для завдання наборі даних, що дозволяє їй налаштувати свої параметри для кращого відповідності нюансам нових даних, використовуючи при цьому широкі знання, які вона вже здобула. Налаштування дозволяє моделі надавати більш точні та релевантні результати в спеціалізованих застосуваннях без необхідності навчати нову модель з нуля.

tip

Оскільки попереднє навчання LLM, яка "розуміє" текст, є досить дорогим, зазвичай легше і дешевше налаштувати відкриті попередньо навченої моделі для виконання конкретного завдання, яке ми хочемо, щоб вона виконувала.

tip

Мета цього розділу - показати, як налаштувати вже попередньо навчена модель, щоб замість генерації нового тексту LLM вибирала ймовірності того, що даний текст буде класифіковано в кожну з наданих категорій (наприклад, чи є текст спамом чи ні).

Підготовка набору даних

Розмір набору даних

Звичайно, для налаштування моделі вам потрібні структуровані дані, щоб спеціалізувати ваш LLM. У прикладі, запропонованому в https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, GPT2 налаштовується для виявлення, чи є електронний лист спамом, чи ні, використовуючи дані з https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.

Цей набір даних містить набагато більше прикладів "не спаму", ніж "спаму", тому книга пропонує використовувати лише стільки ж прикладів "не спаму", скільки "спаму" (отже, видаливши з навчальних даних всі додаткові приклади). У цьому випадку це було 747 прикладів кожного.

Потім 70% набору даних використовується для навчання, 10% для перевірки та 20% для тестування.

  • Набір для перевірки використовується під час навчального етапу для налаштування гіперпараметрів моделі та прийняття рішень щодо архітектури моделі, ефективно допомагаючи запобігти перенавчанню, надаючи зворотний зв'язок про те, як модель працює з невідомими даними. Це дозволяє здійснювати ітеративні поліпшення без упередження фінальної оцінки.
  • Це означає, що хоча дані, включені в цей набір даних, не використовуються безпосередньо для навчання, вони використовуються для налаштування найкращих гіперпараметрів, тому цей набір не може бути використаний для оцінки продуктивності моделі, як тестовий.
  • На відміну від цього, тестовий набір використовується тільки після того, як модель була повністю навчена і всі налаштування завершені; він надає неупереджену оцінку здатності моделі узагальнювати нові, невідомі дані. Ця фінальна оцінка на тестовому наборі дає реалістичне уявлення про те, як модель очікується працювати в реальних застосуваннях.

Довжина записів

Оскільки навчальний приклад очікує записи (тексти електронних листів у цьому випадку) однакової довжини, було вирішено зробити кожен запис таким же великим, як найбільший, додавши ідентифікатори <|endoftext|> як заповнювач.

Ініціалізація моделі

Використовуючи відкриті попередньо навченої ваги, ініціалізуйте модель для навчання. Ми вже робили це раніше і, дотримуючись інструкцій https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, ви можете легко це зробити.

Головка класифікації

У цьому конкретному прикладі (прогнозування, чи є текст спамом, чи ні) нас не цікавить налаштування відповідно до повного словника GPT2, але ми лише хочемо, щоб нова модель сказала, чи є електронний лист спамом (1), чи ні (0). Тому ми збираємося модифікувати фінальний шар, який надає ймовірності для кожного токена словника, на той, який лише надає ймовірності бути спамом або ні (тобто як словник з 2 слів).

python
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(

in_features=BASE_CONFIG["emb_dim"],

out_features=num_classes
)

Параметри для налаштування

Щоб швидко налаштувати, легше не налаштовувати всі параметри, а лише деякі фінальні. Це пов'язано з тим, що відомо, що нижчі шари зазвичай захоплюють основні мовні структури та семантику. Тому просто налаштування останніх шарів зазвичай є достатнім і швидшим.

python
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False

# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True

# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():

param.requires_grad = True

Entries to use for training

В попередніх розділах LLM навчався, зменшуючи втрати кожного передбаченого токена, хоча майже всі передбачені токени були в вхідному реченні (лише 1 в кінці дійсно передбачався), щоб модель краще розуміла мову.

У цьому випадку нас цікавить лише здатність моделі передбачити, чи є модель спамом, чи ні, тому ми звертаємо увагу лише на останній передбачений токен. Отже, потрібно модифікувати наші попередні функції втрат навчання, щоб враховувати лише цей токен.

Це реалізовано в https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb як:

python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0

if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)

with torch.no_grad():
logits = model(input_batch)[:, -1, :]  # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)

num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :]  # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss

Зверніть увагу, що для кожної партії нас цікавлять лише логіти останнього передбаченого токена.

Повний код для тонкої настройки класифікації GPT2

Ви можете знайти весь код для тонкої настройки GPT2 як класифікатора спаму в https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb

Посилання