โ๏ธ Summary
โ ๋ฐ์ด์ธ์ฝ๋ + ๊ต์ฐจ์ธ์ฝ๋ ๊ฒฐํฉ์ ํตํ RAG ์ฑ๋ฅ ํฅ์
โ ์ฌ์ ํ์ต๋ ์ธ์ด๋ชจ๋ธ + Pooling์ธต + ํ์ธํ๋
1. ๊ฒ์ ์ฑ๋ฅ์ ๋์ด๊ธฐ ์ํ ๋ ๊ฐ์ง ๋ฐฉ๋ฒ
โ ๋ฐ์ด์ธ์ฝ๋์ ๊ต์ฐจ์ธ์ฝ๋๋ฅผ ๊ฒฐํฉํด ๋ ๋น ๋ฅธ ๊ฒ์์ ์ํ
โช๏ธ 1) ๋ฐ์ด์ธ์ฝ๋๋ฅผ ์ฌ์ฉํด ๋๊ท๋ชจ์ ๋ฌธ์์์ ๊ฒ์ ์ฟผ๋ฆฌ์ ์ ์ฌํ ์์์ ๋ฌธ์(ex.์์ 100๊ฐ)๋ฅผ ์ ๋ณ
โช๏ธ 2) ์๋ฏธ๊ฒ์์ ํตํด ์ ๋ณํ ์์์ ๋ฌธ์๋ ์ ์ฌ๋๋ฅผ ๋ ์ ํํ ๊ณ์ฐํ ์ ์๋ ๊ต์ฐจ ์ธ์ฝ๋๋ฅผ ์ฌ์ฉํด ์ ์ฌํ ์์๋๋ก ์ฌ์ ๋ ฌ (Rerank)

โ ๋ฐ์ด์ธ์ฝ๋์ ๊ต์ฐจ์ธ์ฝ๋๋ฅผ ๊ฒฐํฉํ ๋, ๊ฒ์ ์ฑ๋ฅ์ ๋ ๋์ด๋ ๋ฐฉ๋ฒ
โช๏ธ 1) ๋ฐ์ด์ธ์ฝ๋ ์ถ๊ฐ ํ์ต : ๋ฌธ์ฅ ์๋ฒ ๋ฉ ๋ชจ๋ธ๋ ํ์ต ๋ฐ์ดํฐ์ ์ ์ฌํ ์ ๋ ฅ์ ๋ํด ์ ์๋ํ๋ฏ๋ก, ์ฌ์ฉํ๋ ค๋ ๋ฐ์ดํฐ์ ์ผ๋ก ์ถ๊ฐํ์ต์ ์ํจ๋ค.
โช๏ธ 2) ๊ต์ฐจ์ธ์ฝ๋ ์ถ๊ฐ ํ์ต : ๊ฒ์๋ ๋ชจ๋ ๋ฌธ์๊ฐ ์๋๋ผ, ์์ ๋ช ๊ฐ์ ์ ๋ ฅ๋ง ํ๋กฌํํธ์ ์ถ๊ฐํด ๊ฒ์ ์ฆ๊ฐ ์์ฑ์ด ํจ๊ณผ์ ์ผ๋ก ์๋ํ๋๋ก ํจ
2. ์ธ์ด๋ชจ๋ธ์ ์๋ฒ ๋ฉ ๋ชจ๋ธ๋ก ๋ง๋ค๊ธฐ

โช๏ธ ๋ฌธ์ฅ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ 2๊ฐ์ ์ธต์ผ๋ก ๋๋๋ค. ์ฒซ ๋ฒ์งธ ์ธต์ ๋๋์ ํ ์คํธ ๋ฐ์ดํฐ๋ก ์ฌ์ ํ์ตํ BERT๋ RoBERTa ๊ฐ์ ์ธ์ด๋ชจ๋ธ์ด๋ค. ๋๋ฒ์งธ ์ธต์ ํ๋ง์ธต์ผ๋ก ์ ๋ ฅ ๋ฌธ์ฅ์ ๊ธธ์ด์ ๋ฐ๋ผ ๋ฌ๋ผ์ง ์ ์๋ ์ถ๋ ฅ ์ฐจ์์ ๊ณ ์ ๋ ์ฐจ์์ผ๋ก ๋ง์ถ๋ ์ญํ ์ ํ๋ค. ํ๋ง์ ๋ฐฉ์์๋ ํด๋์ค๋ชจ๋, ํ๊ท ๋ชจ๋, ์ต๋๋ชจ๋๊ฐ ์๋๋ฐ ์ผ๋ฐ์ ์ผ๋ก ํ๊ท ๋ชจ๋๋ฅผ ๋ง์ด ์ฌ์ฉํ๋ค. ์ด์ฒ๋ผ ์ฌ์ ํ์ต๋ ์ธ์ด๋ชจ๋ธ์ ๋ถ๋ฌ์ค๊ณ ๊ทธ ์์ ํ๋ง ์ธต์ ์ถ๊ฐํ๊ณ ๋ฌธ์ฅ์ ์๋ฏธ๋ฅผ ์ ๋ด์ ์ ์๋๋ก ํ์ตํด์ผ ํ๋ค.
โช๏ธ Sentence-Tranformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฌธ์ฅ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ์ฝ๊ฒ ํ์ฉํ ์ ์๋ค.
2.1 ๋์กฐํ์ต
โ Contrastive learning
โช๏ธ ๊ด๋ จ์ด ์๊ฑฐ๋ ์ ์ฌํ ๋ฐ์ดํฐ๋ ๋ ๊ฐ๊น์์ง๋๋ก ๋ง๋ค๊ณ ๊ด๋ จ์ด ์๊ฑฐ๋ ์ ์ฌํ์ง ์์ ๋ฐ์ดํฐ๋ ๋ ๋ฉ์ด์ง๋๋ก ํ๋ ํ์ต ๋ฐฉ์
โช๏ธ ๋์กฐํ์ต ๋ฐฉ์์ผ๋ก ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ํ์ต์ํฌ ๋, ๋ค์ํ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ ์ ์๋ค. 2๊ฐ์ ๋ฌธ์ฅ์ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ๊ฐ๊ฐ ์ ๋ ฅํ๊ณ ์๋ก ์ ์ฌํ ๋ฐ์ดํฐ์ธ ๊ฒฝ์ฐ๋ ๊ฐ๊น๊ฒ ๊ทธ๋ ์ง ์์ ๊ฒฝ์ฐ๋ ๋ฉ๊ฒ ๋ง๋ค ์ ์๋ค. ๋๋ ์๋ก ์ด์ด์ง๋ ๋ฌธ์ฅ์ด๋ผ๋ฉด ๊ฐ๊น๊ฒ ๊ทธ๋ ์ง ์์ผ๋ฉด ์๋ก ๋ฉ๊ฒ ๋ง๋ค ์ ์๋ค. ๋ง์ง๋ง์ผ๋ก ์๋ก ์ง๋ฌธ๋ต๋ณ ๊ด๊ณ์ธ ๊ฒฝ์ฐ๋ผ๋ฉด ๊ฐ๊น๊ฒ ์๋๋ฉด ๋ฉ๊ฒ ํ์ต์ํฌ ์ ์๋ค.
2.2 ์ค์ต [1]
โ ์ธ์ด๋ชจ๋ธ์ ๊ทธ๋๋ก ๋ถ๋ฌ์ ๋ฌธ์ฅ ์๋ฒ ๋ฉ์ ๋ง๋ค์ด๋ณด๊ธฐ
1) ์ฌ์ ํ์ต๋ ์ธ์ด๋ชจ๋ธ์ ๋ถ๋ฌ์ ๋ฌธ์ฅ ์๋ฒ ๋ฉ ๋ชจ๋ธ ๋ง๋ค๊ธฐ
from sentence_transformers import SentenceTransformer, models
# 1) modules ๋ชจ๋์ ํ์ฉํด ์ฌ์ ํ์ต๋ ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
transformer_model = models.Transformer('klue/roberta-base')
# 2) ํ๊ท ํ๋ง์ธต ์์ฑ
pooling_layer = models.Pooling(
transformer_model.get_word_embedding_dimension(),
pooling_mode_mean_tokens=True
)
# 3) SentenceTransformer ํด๋์ค๋ก ๋ ๋ชจ๋์ ๊ฒฐํฉํด ๋ฌธ์ฅ ์๋ฒ ๋ฉ ๋ชจ๋ธ ์์ฑ
embedding_model = SentenceTransformer(modules=[transformer_model, pooling_layer])
2) ์ค์ต ๋ฐ์ดํฐ์ ๋ถ๋ฌ์ค๊ธฐ
โช๏ธ KLUE์ STS ๋ฐ์ดํฐ์ (๋ ๋ฌธ์ฅ์ด ์๋ก ์ผ๋ง๋ ์ ์ฌํ์ง ์ ์๋ฅผ ๋งค๊ธด ๋ฐ์ดํฐ์ )
โช๏ธ labels ์นผ๋ผ์ ๋ ๋ฌธ์ฅ์ด ์ผ๋ง๋ ์ ์ฌํ์ง๋ฅผ ๋ํ๋ด๋ ๋ค์ํ ํ์์ ๋ ์ด๋ธ์ด ์์ (label์ ์ฌ์ฉํ ์์ )
from datasets import load_dataset
klue_sts_train = load_dataset('klue', 'sts', split='train')
klue_sts_test = load_dataset('klue', 'sts', split='validation')
klue_sts_train[0]
# {'guid': 'klue-sts-v1_train_00000',
# 'source': 'airbnb-rtt',
# 'sentence1': '์์ ์์น๋ ์ฐพ๊ธฐ ์ฝ๊ณ ์ผ๋ฐ์ ์ธ ํ๊ตญ์ ๋ฐ์งํ ์์์
๋๋ค.',
# 'sentence2': '์๋ฐ์์ค์ ์์น๋ ์ฝ๊ฒ ์ฐพ์ ์ ์๊ณ ํ๊ตญ์ ๋ํ์ ์ธ ๋ฐ์งํ ์๋ฐ์์ค์
๋๋ค.',
# 'labels': {'label': 3.7, 'real-label': 3.714285714285714, 'binary-label': 1}}
# ํ์ต ๋ฐ์ดํฐ์
์ 10%๋ฅผ ๊ฒ์ฆ ๋ฐ์ดํฐ์
์ผ๋ก ๊ตฌ์ฑ
klue_sts_train = klue_sts_train.train_test_split(test_size=0.1, seed=42)
klue_sts_train, klue_sts_eval = klue_sts_train['train'], klue_sts_train['test']
3) label ์ ๊ทํํ๊ธฐ
from sentence_transformers import InputExample
# ์ ์ฌ๋ ์ ์๋ฅผ 0~1 ์ฌ์ด๋ก ์ ๊ทํ ํ๊ณ InputExample ๊ฐ์ฒด์ ๋ด๊ธฐ
def prepare_sts_examples(dataset):
examples = []
for data in dataset:
examples.append(
InputExample(
texts=[data['sentence1'], data['sentence2']], # ํ
์คํธ์์ ๋ฆฌ์คํธ ํํ๋ก ์
๋ ฅ
label=data['labels']['label'] / 5.0) # ์ ๊ทํ
)
return examples
# ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ์ํ
train_examples = prepare_sts_examples(klue_sts_train)
eval_examples = prepare_sts_examples(klue_sts_eval)
test_examples = prepare_sts_examples(klue_sts_test)
4) ๋ฐฐ์น ๋ฐ์ดํฐ์ ๋ง๋ค๊ธฐ
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
5) ๊ฒ์ฆ์ ์ํ ํ๊ฐ ๊ฐ์ฒด ์ค๋น ๋ฐ ๋ชจ๋ธ ์ฑ๋ฅ ๊ฒฐ๊ณผ
โช๏ธ EmbeddingSimilarityEvaluator ๋ฅผ ์ฌ์ฉํด ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ํ๊ฐํ ๋ ์ฌ์ฉํ ์ ์๋๋ก ์ค๋น
โช๏ธ from_input_examples ๋ฉ์๋๋ฅผ ์ฌ์ฉํด, ๊ฒ์ฆ ๋ฐ์ดํฐ์ ๊ณผ ํ๊ฐ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ๋ ํ๊ฐ ๊ฐ์ฒด๋ฅผ ์์ฑ
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
eval_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_examples)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_examples)
test_evaluator(embedding_model)
# 0.36460670798564826
โช๏ธ ์ธ์ด๋ชจ๋ธ์ ๊ทธ๋๋ก ๋ฌธ์ฅ ์๋ฒ ๋ฉ ๋ชจ๋ธ๋ก ๋ง๋ embedding_model์ด ์ผ๋ง๋ ๋ฌธ์ฅ์ ์๋ฏธ๋ฅผ ์ ๋ฐ์ํด ๋ฌธ์ฅ ์๋ฒ ๋ฉ์ ์์ฑํ๋์ง ํ์ธํด๋ณด์์ ๋, 0.364๋ก ์ญํ ์ ์ํ์ง ๋ชปํ๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
2.3 ์ค์ต [2] ์ ์ฌํ ๋ฌธ์ฅ ๋ฐ์ดํฐ๋ก ์๋ฒ ๋ฉ ๋ชจ๋ธ ํ์ตํ๊ธฐ
โ ์๋ฒ ๋ฉ ๋ชจ๋ธ ํ์ต
โช๏ธ ๊ธฐ๋ณธ ์ธ์ด๋ชจ๋ธ : klue/roberta-base
โช๏ธ ํ์ต์ ์ฌ์ฉ๋ ์์คํจ์ : CosineSimilarityLoss๋ฅผ ์ฌ์ฉํ๋ค. ํ์ต ๋ฐ์ดํฐ๋ฅผ ๋ฌธ์ฅ ์๋ฒ ๋ฉ์ผ๋ก ๋ณํํ๊ณ ๋ ๋ฌธ์ฅ ์ฌ์ด์ ์ฝ์ฌ์ธ ์ ์ฌ๋์ ์ ๋ต ์ ์ฌ๋๋ฅผ ๋น๊ตํด ํ์ต์ ์ํํ๋ค.
from sentence_transformers import losses
num_epochs = 4
model_name = 'klue/roberta-base'
model_save_path = 'output/training_sts_' + model_name.replace("/", "-")
train_loss = losses.CosineSimilarityLoss(model=embedding_model)
# ์๋ฒ ๋ฉ ๋ชจ๋ธ ํ์ต
embedding_model.fit(
train_objectives=[(train_dataloader, train_loss)],
evaluator=eval_evaluator,
epochs=num_epochs,
evaluation_steps=1000,
warmup_steps=100,
output_path=model_save_path
)
โ ํ์ตํ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ์ฑ๋ฅํ๊ฐ
โช๏ธ ํ์ต๋ ๋ชจ๋ธ์ด ์ ์ฅ๋ ๊ฒฝ๋ก์์ ๋ชจ๋ธ์ ์ฝ์ด์ค๊ณ , ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ํ๊ฐํ๋ค. ์ ์๊ฐ 0.364์์ 0.896์ผ๋ก ํฌ๊ฒ ํฅ์๋ ๊ฒ์ ๋ณผ ์ ์๋ค.
trained_embedding_model = SentenceTransformer(model_save_path)
test_evaluator(trained_embedding_model)
# 0.8965595666246748
โ ํ๊น ํ์ด์ค ํ๋ธ์ ๋ชจ๋ธ ์ ์ฅ
โช๏ธ ๊ณ์ ํ ํฐ์ ํตํด ํ๊น ํ์ด์ค ํ๋ธ์ ์ ๊ทผํด, ๋ชจ๋ธ์ ์ ๋ก๋ํ๋ค.
from huggingface_hub import login
from huggingface_hub import HfApi
login(token='ํ๊น
ํ์ด์ค ํ๋ธ ํ ํฐ ์
๋ ฅ')
api = HfApi()
repo_id="klue-roberta-base-klue-sts"
api.create_repo(repo_id=repo_id)
# ๋ชจ๋ธ ์
๋ก๋
api.upload_folder(
folder_path=model_save_path,
repo_id=f"๋ณธ์ธ์ ํ๊น
ํ์ด์ค ์์ด๋ ์
๋ ฅ/{repo_id}",
repo_type="model",
)
3. ์๋ฒ ๋ฉ ๋ชจ๋ธ ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ
3.1 ํ์ต ์ค๋น
โ [๋ณต์ต] RAG
โช๏ธ RAG๋ ๊ฒ์ ์ฟผ๋ฆฌ์ ๊ด๋ จ๋ ๋ฌธ์๋ฅผ ์ฐพ์, LLM ํ๋กฌํํธ์ ๋งฅ๋ฝ ๋ฐ์ดํฐ๋ก ์ถ๊ฐํ ๋, ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ํ์ฉํ๋ค. ์ข์ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ด๋ผ๋ฉด ๊ฒ์ ์ฟผ๋ฆฌ์ ๊ด๋ จ์๋ ๋ฌธ์๋ ์ ์ฌ๋๊ฐ 1์ ๊ฐ๊น๊ฒ ๋์์ผ ํ๋ค.
โช๏ธ ์๋ฒ ๋ฉ๋ชจ๋ธ์ KLUE์ MRC ๋ฐ์ดํฐ์ (๊ธฐ์ฌ ๋ณธ๋ฌธ ๋ฐ ํด๋น ๊ธฐ์ฌ์ ๊ด๋ จ๋ ์ง๋ฌธ์ ์์งํ ๋ฐ์ดํฐ)์ผ๋ก ์ถ๊ฐํ์ต์์ผ ์ค์ต ๋ฐ์ดํฐ์ ๋ฌธ์ฅ ์ฌ์ด์ ์ ์ฌ๋๋ฅผ ๋ ์ ๊ณ์ฐํ ์ ์๋๋ก ๋ง๋ ๋ค.
โ ๋ฏธ์ธ์กฐ์
โช๏ธ ๋ฌธ์ฅ ์๋ฒ ๋ฉ ๋ชจ๋ธ๋, ๋ค๋ฅธ ๋ฅ๋ฌ๋ ๋ชจ๋ธ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ํ์ต ๋ฐ์ดํฐ์ ์ ์ฌํ ๋ฐ์ดํฐ์ผ ๋ ๊ฐ์ฅ ์ ๋์ํ๋ค.
โช๏ธ ๋ฐ๋ผ์, ์ฌ์ ํ์ต๋ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ๊ทธ๋๋ก ํ์ฉํ๋ ๊ฒฝ์ฐ, ์ฌ์ ํ์ต์ ์ฌ์ฉ๋ ๋ฐ์ดํฐ์ ์ด, ์ค์ต์ ์ฌ์ฉํ๋ MRC ๋ฐ์ดํฐ์ ๊ณผ ๋จ์ด, ์ฃผ์ ๋ฑ์ด ๋ค๋ฅด๋ฉด ์ฑ๋ฅ์ด ๋ฎ์์ง๋ค. ๋ฐ๋ผ์ MRC ๋ฐ์ดํฐ์ ์ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ํ์ฉํ๋ ค๊ณ ํ๋ค๋ฉด, ๊ทธ ๋ชฉ์ ์ ๋ง๊ฒ MRC ๋ฐ์ดํฐ์ ์ผ๋ก ๋ฏธ์ธ์กฐ์ ํด์ผ ํ๋ค.
โ ์ค์ต
1) 2์ฅ์์ ์ ์ฅํ ๊ธฐ๋ณธ ์๋ฒ ๋ฉ ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
from sentence_transformers import SentenceTransformer
sentence_model = SentenceTransformer('shangrilar/klue-roberta-base-klue-sts')
2) ์ง๋ฌธ๊ณผ ๊ด๋ จ์๋ ๊ธฐ์ฌ ์ถ๊ฐํ๊ธฐ
โช๏ธ question - context ์นผ๋ผ์ ๊ด๋ จ์ด ์์ผ๋ฏ๋ก label 1์ ๋ถ์ฌํ๊ณ , ์์๋ก ์ง๋ฌธ๊ณผ ๊ด๋ จ์๋ ๊ธฐ์ฌ๋ฅผ ์ถ๊ฐํด ๋ง๋ irrelevant_context ์นผ๋ผ์ ๊ด๋ จ์ด ์์ผ๋ฏ๋ก label 0์ ๋ถ์ฌ
def add_ir_context(df):
irrelevant_contexts = []
for idx, row in df.iterrows():
title = row['title']
irrelevant_contexts.append(df.query(f"title != '{title}'").sample(n=1)['context'].values[0])
df['irrelevant_context'] = irrelevant_contexts
return df
df_train_ir = add_ir_context(df_train)
3) ๊ธฐ๋ณธ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ๋ฏธ์ธ์กฐ์ ํ์ง ์์ ์ํ์์ ์ฑ๋ฅ ํ๊ฐ
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
examples
)
evaluator(sentence_model)
# 0.8151553052035344
3.2 ๋ฏธ์ธ์กฐ์
โ MNR์์ค์ ํ์ฉํด ๋ฏธ์ธ์กฐ์ ํ๊ธฐ
โช๏ธ ๊ธฐ๋ณธ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ๋ง๋ค ๋ ์ฝ์ฌ์ธ ์ ์ฌ๋ ์์ค์ ํ์ฉํด ๋ชจ๋ธ์ ํ์ต์์ผฐ์๋๋ฐ, ์ด๋ฒ์๋ Multiple Negatives Ranking ์์ค์ ์ฌ์ฉํด ๋ฏธ์ธ์กฐ์ ์ ์๋ํ๋ค.
โช๏ธ MNR์ MRC ๋ฐ์ดํฐ์ ๊ณผ ๊ฐ์ด ์๋ก ๊ด๋ จ์ด ์๋ ๋ฌธ์ฅ๋ง ์๋ ๊ฒฝ์ฐ ์ฌ์ฉํ๊ธฐ ์ข์ ์์คํจ์์ด๋ค. MNR ์์ค์ ์ฌ์ฉํ๋ฉด, ์๋์ผ๋ก ํ๋์ ๋ฐฐ์น ๋ฐ์ดํฐ ์์์ ๋ค๋ฅธ ๋ฐ์ดํฐ์ ๊ธฐ์ฌ ๋ณธ๋ฌธ์ ๊ด๋ จ์๋ ๋ฐ์ดํฐ๋ก ์ฌ์ฉํด ๋ชจ๋ธ์ ํ์ต์ํค๊ธฐ ๋๋ฌธ์, ์๋ก ๊ด๋ จ์ด ์๋ ๋ฐ์ดํฐ๋ง์ผ๋ก ํ์ต ๋ฐ์ดํฐ๋ฅผ ๊ตฌ์ฑํ๋ฉด ๋๋ค.
1) ๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ
# ๊ธ์ ๋ฐ์ดํฐ๋ง์ผ๋ก ํ์ต ๋ฐ์ดํฐ ๊ตฌ์ฑ
train_samples = []
for idx, row in df_train_ir.iterrows():
train_samples.append(InputExample(
texts=[row['question'], row['context']]
))
# ์ค๋ณต ํ์ต ๋ฐ์ดํฐ ์ ๊ฑฐ
from sentence_transformers import datasets
batch_size = 16
loader = datasets.NoDuplicatesDataLoader(
train_samples, batch_size=batch_size)
2) MNR ์์คํจ์ ๋ถ๋ฌ์ค๊ธฐ ๋ฐ ๋ฏธ์ธ์กฐ์
from sentence_transformers import losses
#[์ฐธ๊ณ ] sentence_model = SentenceTransformer('shangrilar/klue-roberta-base-klue-sts')
# ์์คํจ์ ๋ถ๋ฌ์ค๊ธฐ
loss = losses.MultipleNegativesRankingLoss(sentence_model)
# ๋ฏธ์ธ์กฐ์ ์ํ
epochs = 1
save_path = './klue_mrc_mnr'
sentence_model.fit(
train_objectives=[(loader, loss)],
epochs=epochs,
warmup_steps=100,
output_path=save_path,
show_progress_bar=True
)
3) ํ๊ฐ
โช๏ธ ๋ฏธ์ธ์กฐ์ ์ ์ 0.815์๋ ์ฑ๋ฅ์ด 0.86์ผ๋ก ์์น
evaluator(sentence_model)
# 0.8600968992433692
4) ๋ชจ๋ธ ์ ๋ก๋
from huggingface_hub import HfApi
api = HfApi()
repo_id = "klue-roberta-base-klue-sts-mrc"
api.create_repo(repo_id=repo_id)
api.upload_folder(
folder_path=save_path,
repo_id=f"๋ณธ์ธ์ ์์ด๋ ์
๋ ฅ/{repo_id}",
repo_type="model",
)
4. ๊ฒ์ ํ์ง์ ๋์ด๋ ์์ ์ฌ์ ๋ ฌ
โ ๊ต์ฐจ์ธ์ฝ๋ ๋ฏธ์ธ์กฐ์
โช๏ธ ๊ต์ฐจ์ธ์ฝ๋๋ 2๊ฐ ๋ฌธ์ฅ์ ์ ๋ ฅ๋ฐ์ ๋ฌธ์ฅ ์ฌ์ด์ ๊ด๊ณ๋ฅผ ํ์ตํ๋ฏ๋ก, ๋ฌธ์ฅ๋ถ๋ฅ๋ชจ๋ธ์ ์ฌ์ฉํ๋ค.
โช๏ธ ๋ฌธ์ฅ๋ถ๋ฅ ๋ชจ๋ธ์ด๋ผ, transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ๋ชจ๋ธ์ ์ง์ ํ์ตํ๋ ๋ฐฉ์๋ ๊ฐ๋ฅํ์ง๋ง, ์ค์ต์์๋ CrossEncoder์ ๋ฏธ์ธ์กฐ์ ๋ฉ์๋๋ฅผ ์ฌ์ฉํ๋ค.
1) CrossEncoder ๋ถ๋ฌ์ค๊ธฐ
from sentence_transformers.cross_encoder import CrossEncoder
cross_model = CrossEncoder('klue/roberta-small', num_labels=1)
โช๏ธ ๊ต์ฐจ์ธ์ฝ๋๋ ๋ง์ ๊ณ์ฐ์ ํด์ผ ํ๋ฏ๋ก ํ๋ผ๋ฏธํฐ์๊ฐ ์์ roberta-small ๋ชจ๋ธ์ ์ฌ์ฉํ๋ค.
โช๏ธ roberta-small ๋ชจ๋ธ์ ๋ถ๋ฅํค๋๊ฐ ์๋ ์ธ์ด๋ชจ๋ธ์ด๋ฏ๋ก, ๊ต์ฐจ์ธ์ฝ๋๋ก ๋ถ๋ฌ์ค๋ ค๋ฉด ๋ถ๋ฅ ํค๋๋ ๋๋ค์ผ๋ก ์ด๊ธฐํ๋๋ค(์ฑ๋ฅ์ด ๋ฎ์์๋ฐ์ ์์).
2) ์ด๊ธฐ ์ฑ๋ฅ ๊ฒฐ๊ณผ
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
ce_evaluator = CECorrelationEvaluator.from_input_examples(examples)
ce_evaluator(cross_model)
# 0.003316821814673943
โช๏ธ ๋ฏธ์ธ์กฐ์ ํ์ง ์์ ๊ต์ฐจ ์ธ์ฝ๋์ ์ฑ๋ฅ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด, ๋ฌธ์ฅ์ ๊ด๋ จ์ฑ์ ์ ๊ณ์ฐํ์ง ๋ชปํ๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
3) ๊ต์ฐจ์ธ์ฝ๋ ํ์ต ์ํ
# ํ์ต ๋ฐ์ดํฐ์
์ค๋น
train_samples = []
for idx, row in df_train_ir.iterrows():
train_samples.append(InputExample(
texts=[row['question'], row['context']], label=1
))
train_samples.append(InputExample(
texts=[row['question'], row['irrelevant_context']], label=0
))
train_batch_size = 16
num_epochs = 1
model_save_path = 'output/training_mrc'
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
cross_model.fit(
train_dataloader=train_dataloader,
epochs=num_epochs,
warmup_steps=100,
output_path=model_save_path
)
# ๊ฒฐ๊ณผ ํ๊ฐ
ce_evaluator(cross_model)
# 0.8650250798639563
โช๏ธ ํ์ต ์ํ ํ, ์ฑ๋ฅ์ ๋ค์ ํ์ธํด๋ณด๋ฉด 0.865๋ก ํฌ๊ฒ ๋์์ง ๊ฒ์ ํ์ธํ ์ ์๋ค.
5. ๋ฐ์ด์ธ์ฝ๋์ ๊ต์ฐจ ์ธ์ฝ๋๋ก ๊ฐ์ ๋ RAG๊ตฌํํ๊ธฐ
โ 3๊ฐ ๋ชจ๋ธ ํ์ต
โช๏ธ 1) ์ธ์ด๋ชจ๋ธ์ ์๋ฒ ๋ฉ ๋ชจ๋ธ (์ฌ์ ํ์ต๋ ์ธ์ด๋ชจ๋ธ+Pooling์ธต) ๋ก ๋ณํํ ๊ธฐ๋ณธ ์๋ฒ ๋ฉ ๋ชจ๋ธ
โช๏ธ 2) ๊ธฐ๋ณธ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ MRC ๋ฐ์ดํฐ์ ์ผ๋ก ๋ฏธ์ธ์กฐ์ ํ ์๋ฒ ๋ฉ ๋ชจ๋ธ
โช๏ธ 3) MRC ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ต์ํจ ๊ต์ฐจ์ธ์ฝ๋
โ ์ฑ๋ฅ์งํ
โช๏ธ HitRate@10 : ์ง๋ฌธ ์นผ๋ผ์ ์ ๋ ฅํ์ ๋, ๊ฒ์๋ ์์ 10๊ฐ ๊ธฐ์ฌ ๋ณธ๋ฌธ์ ์ ๋ต์ด ์๋ ๋น์จ
โช๏ธ ๊ด๋ จํด evaluate_hit_rate ํจ์๋ฅผ ์ ์ํจ
5.1 ๊ธฐ๋ณธ ์๋ฒ ๋ฉ ๋ชจ๋ธ๋ก ๊ฒ์
from sentence_transformers import SentenceTransformer
base_embedding_model = SentenceTransformer('shangrilar/klue-roberta-base-klue-sts')
base_index = make_embedding_index(base_embedding_model, klue_mrc_test['context'])
evaluate_hit_rate(klue_mrc_test, base_embedding_model, base_index, 10)
# (0.88, 13.216430425643921)
โช๏ธ 88%์ ๋ฐ์ดํฐ์์ ์ ๋ต์ ์ ์ฐพ์๊ณ , ํ๊ฐ์๋ 13์ด๊ฐ ๊ฑธ๋ ธ๋ค. ์ด 1000๊ฐ์ ํ๊ฐ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ผ๋ฏ๋ก ๋ฐ์ดํฐ ํ๋๋น 0.013์ด๊ฐ ์์๋จ
5.2 ๋ฏธ์ธ์กฐ์ ํ ์๋ฒ ๋ฉ ๋ชจ๋ธ๋ก ๊ฒ์
finetuned_embedding_model = SentenceTransformer('shangrilar/klue-roberta-base-klue-sts-mrc')
finetuned_index = make_embedding_index(finetuned_embedding_model, klue_mrc_test['context'])
evaluate_hit_rate(klue_mrc_test, finetuned_embedding_model, finetuned_index, 10)
# (0.946, 14.309881687164307)
โช๏ธ 94.6%์ ๋ฐ์ดํฐ์์ ์ ๋ต ๊ธฐ์ฌ๋ฅผ ์ ํํ ๊ฐ์ ธ์๋ค. ์ฝ๊ฐ์ ๋ฏธ์ธ์กฐ์ ๋ง์ผ๋ก๋ 7.5%์ ๋ ์ฑ๋ฅ์ด ํฅ์๋์๋ค.
5.3 ๋ฏธ์ธ์กฐ์ ํ ์๋ฒ ๋ฉ ๋ชจ๋ธ๊ณผ ๊ต์ฐจ ์ธ์ฝ๋ ์กฐํฉํ๊ธฐ
โ ๋ฐ์ด์ธ์ฝ๋ + ๊ต์ฐจ์ธ์ฝ๋
โช๏ธ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ์์ N ๊ฐ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ์, ๊ต์ฐจ์ธ์ฝ๋๊ฐ ์์๋ฅผ ์ ์ฌ๋์์ผ๋ก ์ฌ์ ๋ ฌํ ํ ์์ K๊ฐ๋ฅผ ์ถ์ถํ๋ฉด, ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ํตํด ์ ์ฌ๋๊ฐ ๋์ ์์ K๊ฐ๋ฅผ ๋ฐ๋ก ๋ฝ์์ ๋๋ณด๋ค ์ฑ๋ฅ์ ๋์ผ ์ ์๋ค.
โช๏ธ ๊ต์ฐจ์ธ์ฝ๋์ ๊ฒฝ์ฐ, ์๋๊ฐ ๋๋ฆฌ๋ฏ๋ก ์ ์ฒด ๋ฌธ์๋ฅผ ๊ฒ์ ๋์์ผ๋ก ํ์ง ์๊ณ ์์ N๊ฐ๋ง์ ๋์์ผ๋ก ๊ณ์ฐํ๋๋ก ๋ฒ์๋ฅผ ์ขํ๋ค.
import time
import numpy as np
from tqdm.auto import tqdm
# ์์ ์ฌ์ ๋ ฌ์ ํฌํจํ ํ๊ฐํจ์
def evaluate_hit_rate_with_rerank(datasets, embedding_model, cross_model, index, bi_k=30, cross_k=10):
start_time = time.time()
predictions = []
for question_idx, question in enumerate(tqdm(datasets['question'])):
indices = find_embedding_top_k(question, embedding_model, index, bi_k)[0]
predictions.append(rerank_top_k(cross_model, question_idx, indices, k=cross_k))
total_prediction_count = len(predictions)
hit_count = 0
questions = datasets['question']
contexts = datasets['context']
for idx, prediction in enumerate(predictions):
for pred in prediction:
if contexts[pred] == contexts[idx]:
hit_count += 1
break
end_time = time.time()
return hit_count / total_prediction_count, end_time - start_time, predictions
# ๊ฒฐ๊ณผ
hit_rate, cosumed_time, predictions = evaluate_hit_rate_with_rerank(klue_mrc_test, finetuned_embedding_model, cross_model, finetuned_index, bi_k=30, cross_k=10)
hit_rate, cosumed_time
# (0.973, 1103.055629491806)
โช๏ธ 97.3% ์ ๋ฐ์ดํฐ์์ ์ ๋ต์ ์ ์ฐพ์ ๊ฒ์ ๋ณผ ์ ์๋ค.
'1๏ธโฃ AIโขDS > ๐ LLM' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
| 13. LLM ์ด์ํ๊ธฐ (0) | 2025.10.27 |
|---|---|
| 12. ๋ฒกํฐ๋ฐ์ดํฐ๋ฒ ์ด์ค๋ก ํ์ฅํ๊ธฐ : RAG ๊ตฌํํ๊ธฐ (0) | 2025.10.22 |
| [์ฑ ์คํฐ๋] 10-(2). ์ค์ต : ์๋ฏธ๊ฒ์ ๊ตฌํํ๊ธฐ (0) | 2025.09.19 |
| [์ฑ ์คํฐ๋] 10-(1). ์๋ฒ ๋ฉ ๋ชจ๋ธ๋ก ๋ฐ์ดํฐ ์๋ฏธ ์์ถํ๊ธฐ (0) | 2025.09.18 |
| [์ฑ ์คํฐ๋] 9. LLM ์ ํ๋ฆฌ์ผ์ด์ ๊ฐ๋ฐํ๊ธฐ (1) | 2025.09.08 |
๋๊ธ