Text Classification with the DistilBERT Model

2024. 7. 15. 12:59·LLM

Step 1) Load the data

from datasets import load_dataset
from datasets import ClassLabel
import pandas as pd

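emotions = load_dataset("emotion", trust_remote_code=True)
# Make the label names explicit. cast_column updates the dataset schema;
# assigning to features['label'] is not guaranteed to do so.
emotions = emotions.cast_column("label", ClassLabel(
    num_classes=6,
    names=['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']))

# View the training split as a pandas DataFrame for exploration
emotions.set_format(type="pandas")
df = emotions["train"][:]

def label_int2str(row):
    # Map an integer label (e.g. 0) to its name (e.g. 'sadness')
    return emotions["train"].features["label"].int2str(row)

df["label_name"] = df["label"].apply(label_int2str)

# Restore the default output format; the tokenizer in step 2 expects lists, not pandas objects
emotions.reset_format()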
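`ClassLabel` stores the label names as an ordered list, so `int2str` is essentially an index lookup. The pure-Python sketch below mimics that mapping and also computes the class distribution, which is worth checking because the emotion dataset is not balanced (the majority-class baseline in step 3 already reaches 0.352). The helpers `int2str`, `str2int`, `class_distribution` and the `toy_labels` list are illustrative stand-ins, not the `datasets` implementation or real dataset statistics.

```python
from collections import Counter

# Same order as the ClassLabel names used in step 1
NAMES = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']


def int2str(label_id: int) -> str:
    """Index lookup, mimicking ClassLabel.int2str for a single id."""
    if not 0 <= label_id < len(NAMES):
        raise ValueError(f"label id {label_id} out of range")
    return NAMES[label_id]


def str2int(name: str) -> int:
    """Inverse lookup, mimicking ClassLabel.str2int."""
    return NAMES.index(name)


def class_distribution(label_ids):
    """Return {label_name: fraction of examples} for a list of integer labels."""
    counts = Counter(int2str(i) for i in label_ids)
    total = len(label_ids)
    return {name: counts[name] / total for name in NAMES}


# Hypothetical toy labels (not real dataset statistics)
toy_labels = [1, 1, 0, 1, 3, 0, 4, 1, 2, 5]
dist = class_distribution(toy_labels)
print(dist)
```

On the real data, `df["label_name"].value_counts(normalize=True)` gives the same information directly.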

 

Step 2) WordPiece tokenization

from transformers import AutoTokenizer

model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

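def tokenize(batch):
    # padding=True pads to the longest sequence in the batch;
    # truncation=True cuts sequences longer than the model's maximum input length
    return tokenizer(batch["text"], padding=True, truncation=True)

# batch_size=None tokenizes each split as one batch, so every sequence
# in a split is padded to the same length
emotions_encoded = emotions.map(tokenize, batched=True, batch_size=None)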
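The `tokenize` function relies on two mechanisms: WordPiece, which splits each word into the longest subword pieces found in the vocabulary (marking word-internal pieces with `##`), and padding, which extends every sequence in a batch to the same length while an `attention_mask` marks the real tokens. The sketch below is a simplified pure-Python version of both, assuming a tiny hand-made `vocab` and pad id 0 (the `[PAD]` id in the BERT uncased vocabulary). The real tokenizer also splits off punctuation, adds `[CLS]`/`[SEP]`, and maps pieces to integer ids; `wordpiece`, `toy_tokenize`, and `pad_batch` are hypothetical helpers.

```python
def wordpiece(word, vocab, unk="[UNK]", max_chars=100):
    """Greedy longest-match-first WordPiece split of a single lowercase word."""
    if len(word) > max_chars:
        return [unk]
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it from the right
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # word-internal pieces carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no vocabulary piece fits: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


def toy_tokenize(text, vocab):
    """Lowercase, split on whitespace, then WordPiece each word (no punctuation handling)."""
    tokens = []
    for word in text.lower().split():
        tokens.extend(wordpiece(word, vocab))
    return tokens


def pad_batch(batch_ids, pad_id=0):
    """Pad id lists to the longest one and build the matching attention masks."""
    max_len = max(len(ids) for ids in batch_ids)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch_ids]
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Hypothetical toy vocabulary, not the real DistilBERT vocabulary
vocab = {"i", "love", "token", "##ization", "un", "##aff", "##able"}
print(toy_tokenize("I love tokenization", vocab))
# Hypothetical id sequences of different lengths
print(pad_batch([[101, 7], [101, 7, 8, 102]]))
```

To see the real split, `tokenizer.convert_ids_to_tokens(tokenizer("I love tokenization")["input_ids"])` prints the tokens the DistilBERT tokenizer actually produces.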

 

Step 3) Train a classifier (Transformer as a feature extractor)

from transformers import AutoModel
import torch
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.dummy import DummyClassifier

model_ckpt = "distilbert-base-uncased"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(model_ckpt).to(device)

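def extract_hidden_states(batch):
    # Move only the inputs the model accepts (input_ids, attention_mask) to the device
    inputs = {k:v.to(device) for k,v in batch.items()
              if k in tokenizer.model_input_names}
    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state
    # Use the hidden state of the first token ([CLS]) as the sentence feature vector
    return {"hidden_state": last_hidden_state[:,0].cpu().numpy()}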
    
emotions_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
emotions_hidden = emotions_encoded.map(extract_hidden_states, batched=True)

X_train = np.array(emotions_hidden["train"]["hidden_state"])
X_valid = np.array(emotions_hidden["validation"]["hidden_state"])
y_train = np.array(emotions_hidden["train"]["label"])
y_valid = np.array(emotions_hidden["validation"]["label"])

# Logistic regression on the [CLS] features
lr_clf = LogisticRegression(max_iter=3000)
lr_clf.fit(X_train, y_train)
lr_clf.score(X_valid, y_valid) # 0.6335 (validation accuracy)

# Majority-class baseline (DummyClassifier): always predicts the most frequent training label
dummy_clf = DummyClassifier(strategy="most_frequent")
dummy_clf.fit(X_train, y_train)
dummy_clf.score(X_valid, y_valid) # 0.352
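Step 3 reduces DistilBERT's output, shaped `[batch, seq_len, hidden]` (hidden = 768 for `distilbert-base-uncased`), to one vector per text by keeping the first ([CLS]) token, and then compares logistic regression against a majority-class baseline. The NumPy sketch below works on a made-up tensor: `cls_pool` mirrors `last_hidden_state[:,0]`, `majority_baseline_accuracy` mirrors what `DummyClassifier(strategy="most_frequent")` scores, and `mean_pool` is an alternative pooling shown only for comparison; it is not what this post uses. All three function names are hypothetical.

```python
import numpy as np


def cls_pool(last_hidden_state):
    """[batch, seq_len, hidden] -> [batch, hidden]: keep the first ([CLS]) token."""
    return last_hidden_state[:, 0]


def mean_pool(last_hidden_state, attention_mask):
    """Alternative (not used in this post): average over non-padding tokens only."""
    mask = attention_mask[:, :, None].astype(last_hidden_state.dtype)
    return (last_hidden_state * mask).sum(axis=1) / mask.sum(axis=1)


def majority_baseline_accuracy(y_train, y_valid):
    """Accuracy of always predicting the most frequent training label."""
    values, counts = np.unique(y_train, return_counts=True)
    majority = values[np.argmax(counts)]
    return float(np.mean(y_valid == majority))


# Hypothetical tiny tensor: batch of 2, seq_len 3, hidden size 2
h = np.arange(12, dtype=np.float32).reshape(2, 3, 2)
mask = np.array([[1, 1, 0], [1, 1, 1]])  # last token of the first example is padding
print(cls_pool(h))
print(mean_pool(h, mask))
print(majority_baseline_accuracy(np.array([1, 1, 0, 2]), np.array([1, 0, 1, 1, 3])))
```

The baseline matters because an accuracy figure means little on its own: with imbalanced classes, always guessing the largest class already scores well above 1/6.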

 

Step 4) Train a classifier (fine-tuning the whole model)

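from transformers import AutoModelForSequenceClassification
from transformers import Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(pred):
    # Accuracy and weighted F1 from the Trainer's predictions (used by Trainer below)
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}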

num_labels = 6
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt, num_labels=num_labels)
         .to(device))
         
batch_size = 64
logging_steps = len(emotions_encoded["train"]) // batch_size
model_name = f"{model_ckpt}-finetuned-emotion"
training_args = TrainingArguments(output_dir=model_name,
                                  num_train_epochs=2,
                                  learning_rate=2e-5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=0.01,
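                                  evaluation_strategy="epoch",  # renamed eval_strategy in newer transformers versions
                                  disable_tqdm=False,
                                  logging_steps=logging_steps,
                                  push_to_hub=True,  # requires a Hugging Face Hub login; set False to save locally only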
                                  save_strategy="epoch",
                                  load_best_model_at_end=True,
                                  log_level="error")
                                            
trainer = Trainer(model=model, args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=emotions_encoded["train"],
                  eval_dataset=emotions_encoded["validation"],
                  tokenizer=tokenizer)  # newer transformers versions use processing_class=tokenizer
trainer.train()

preds_output = trainer.predict(emotions_encoded["validation"])
preds_output.metrics
#{'test_loss': 0.2192373424768448,
#'test_accuracy': 0.9275,
#'test_f1': 0.9277315829088285,
#'test_runtime': 3.8546,
#'test_samples_per_second': 518.866,
#'test_steps_per_second': 8.302}
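The `test_accuracy` and `test_f1` values above are computed from the argmax of the logits. Assuming `compute_metrics` reports a support-weighted F1 (scikit-learn's `f1_score(..., average="weighted")`), the computation reduces to the pure-Python sketch below with made-up logits; weighting by support means frequent classes count more than rare ones. The helper names are hypothetical.

```python
from collections import Counter


def argmax(row):
    """Index of the largest logit (first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def weighted_f1(y_true, y_pred):
    """Per-class F1, averaged with weights equal to each class's support."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for cls, n_true in support.items():
        tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
        n_pred = sum(p == cls for p in y_pred)
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true
        # F1 is defined as 0 when there are no true positives
        f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
        score += f1 * n_true / total
    return score


# Hypothetical logits for 4 validation examples over 3 classes
logits = [[2.0, 0.1, 0.3], [1.6, 1.5, 0.1], [0.1, 0.3, 0.9], [1.2, 0.8, 0.1]]
y_true = [0, 1, 2, 1]
y_pred = [argmax(row) for row in logits]
print(accuracy(y_true, y_pred), weighted_f1(y_true, y_pred))
```

Here the model never predicts class 1, so its F1 is 0 and it pulls the weighted F1 (0.375) below the accuracy (0.5).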

Step 5) Error analysis

from torch.nn.functional import cross_entropy

def forward_pass_with_label(batch):
    # Move all input tensors to the same device as the model
    inputs = {k:v.to(device) for k,v in batch.items()
              if k in tokenizer.model_input_names}

    with torch.no_grad():
        output = model(**inputs)
        pred_label = torch.argmax(output.logits, dim=-1)
        # reduction="none" keeps one loss value per example instead of the batch mean
        loss = cross_entropy(output.logits, batch["label"].to(device),
                             reduction="none")

    # Move the outputs back to the CPU so they are compatible with the other dataset columns
    return {"loss": loss.cpu().numpy(),
            "predicted_label": pred_label.cpu().numpy()}
            
emotions_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
emotions_encoded["validation"] = emotions_encoded["validation"].map(forward_pass_with_label, batched=True, batch_size=16)

emotions_encoded.set_format("pandas")
cols = ["text", "label", "predicted_label", "loss"]
df_test = emotions_encoded["validation"][:][cols]
df_test["label"] = df_test["label"].apply(label_int2str)
df_test["predicted_label"] = (df_test["predicted_label"].apply(label_int2str))

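# Validation examples with the highest loss: often confident mistakes or possibly mislabeled samples
df_test.sort_values("loss", ascending=False).head()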
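Sorting by loss works because `cross_entropy(..., reduction="none")` returns one value per example: for logits z and true class y, the loss is -log softmax(z)[y]. A pure-Python sketch with made-up logits shows why a confident wrong prediction ends up at the top of the sorted table; `log_softmax` and `per_example_loss` are illustrative helpers, not PyTorch functions.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def per_example_loss(batch_logits, labels):
    """Cross-entropy for each example separately (like reduction="none")."""
    return [-log_softmax(row)[y] for row, y in zip(batch_logits, labels)]


# Hypothetical logits for 3 examples over 3 classes
batch_logits = [[3.0, 0.0, 0.0],   # confident and correct (label 0)
                [0.0, 0.0, 0.0],   # uncertain (label 1)
                [3.0, 0.0, 0.0]]   # confident and wrong (label 2)
labels = [0, 1, 2]
losses = per_example_loss(batch_logits, labels)

# Indices from highest to lowest loss, like sort_values("loss", ascending=False)
hardest_first = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
print(losses, hardest_first)
```

A uniform prediction costs log(3) for three classes, while a confident wrong prediction costs much more, so such examples dominate the top of the sorted table.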

 

Step 6) Load the model and test it

from transformers import pipeline
import matplotlib.pyplot as plt

model_id = "haesun/distilbert-base-uncased-finetuned-emotion"
classifier = pipeline("text-classification", model=model_id)

custom_tweet = "I saw a movie today and it was really good."
preds = classifier(custom_tweet, top_k=None)  # scores for all 6 classes, sorted by score

# Sort by label to restore class-index order
# (assumes the model reports labels as LABEL_0 ... LABEL_5)
preds_sorted = sorted(preds, key=lambda d: d['label'])
preds_df = pd.DataFrame(preds_sorted)
labels = emotions["train"].features["label"].names
plt.bar(labels, 100 * preds_df["score"], color='C0')
plt.title(f'"{custom_tweet}"')
plt.ylabel("Class probability (%)")
plt.show()
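With `top_k=None` the pipeline returns every class ordered by score, so the code above sorts by the label string to restore class order. That works for `LABEL_0` ... `LABEL_5`, but plain string sorting breaks once there are ten or more classes, because `'LABEL_10'` sorts before `'LABEL_2'`. The sketch below parses the integer suffix instead, assuming the model reports labels in the `LABEL_<k>` form; `order_by_class_index` is a hypothetical helper and the scores are made up.

```python
def order_by_class_index(preds, names):
    """Sort pipeline-style [{'label': 'LABEL_k', 'score': s}, ...] by k
    and pair each score with its human-readable class name."""
    def index(d):
        # Assumes labels look like 'LABEL_<k>'; parse k as an integer
        return int(d["label"].rsplit("_", 1)[1])

    ordered = sorted(preds, key=index)
    return [(names[index(d)], d["score"]) for d in ordered]


names = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
# Hypothetical pipeline output, already ordered by score as top_k=None returns it
preds = [{'label': 'LABEL_1', 'score': 0.9}, {'label': 'LABEL_0', 'score': 0.05},
         {'label': 'LABEL_5', 'score': 0.02}, {'label': 'LABEL_2', 'score': 0.01},
         {'label': 'LABEL_3', 'score': 0.01}, {'label': 'LABEL_4', 'score': 0.01}]
rows = order_by_class_index(preds, names)
print(rows)
```

The resulting `(name, score)` pairs can be unzipped and passed straight to `plt.bar`.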

 

This post is based on Chapter 2 of the book 'Natural Language Processing with Transformers' (Korean edition: '트랜스포머를 활용한 자연어 처리').
