
11. ์ž์‹ ์˜ ๋ฐ์ดํ„ฐ์— ๋งž์ถ˜ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋งŒ๋“ค๊ธฐ : RAG ๊ฐœ์„ ํ•˜๊ธฐ

by isdawell 2025. 10. 19.

 

 

โ˜€๏ธ  Summary 

โ  ๋ฐ”์ด์ธ์ฝ”๋” + ๊ต์ฐจ์ธ์ฝ”๋” ๊ฒฐํ•ฉ์„ ํ†ตํ•œ RAG ์„ฑ๋Šฅ ํ–ฅ์ƒ 
โ  ์‚ฌ์ „ํ•™์Šต๋œ ์–ธ์–ด๋ชจ๋ธ + Pooling์ธต + ํŒŒ์ธํŠœ๋‹ 

 

 

 

 

1.   Two Ways to Improve Retrieval Performance 


 

โ  ๋ฐ”์ด์ธ์ฝ”๋”์™€ ๊ต์ฐจ์ธ์ฝ”๋”๋ฅผ ๊ฒฐํ•ฉํ•ด ๋” ๋น ๋ฅธ ๊ฒ€์ƒ‰์„ ์ˆ˜ํ–‰ 

 โ†ช๏ธŽ  1) ๋ฐ”์ด์ธ์ฝ”๋”๋ฅผ ์‚ฌ์šฉํ•ด ๋Œ€๊ทœ๋ชจ์˜ ๋ฌธ์„œ์—์„œ ๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ์™€ ์œ ์‚ฌํ•œ ์†Œ์ˆ˜์˜ ๋ฌธ์„œ(ex.์ƒ์œ„ 100๊ฐœ)๋ฅผ ์„ ๋ณ„ 

 โ†ช๏ธŽ  2) ์˜๋ฏธ๊ฒ€์ƒ‰์„ ํ†ตํ•ด ์„ ๋ณ„ํ•œ ์†Œ์ˆ˜์˜ ๋ฌธ์„œ๋Š” ์œ ์‚ฌ๋„๋ฅผ ๋” ์ •ํ™•ํžˆ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ๋Š” ๊ต์ฐจ ์ธ์ฝ”๋”๋ฅผ ์‚ฌ์šฉํ•ด ์œ ์‚ฌํ•œ ์ˆœ์„œ๋Œ€๋กœ ์žฌ์ •๋ ฌ (Rerank) 

 

https://cohere.com/blog/rerank

 

 

โ  ๋ฐ”์ด์ธ์ฝ”๋”์™€ ๊ต์ฐจ์ธ์ฝ”๋”๋ฅผ ๊ฒฐํ•ฉํ•  ๋•Œ, ๊ฒ€์ƒ‰ ์„ฑ๋Šฅ์„ ๋” ๋†’์ด๋Š” ๋ฐฉ๋ฒ•

 โ†ช๏ธŽ  1) ๋ฐ”์ด์ธ์ฝ”๋” ์ถ”๊ฐ€ ํ•™์Šต : ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ๋„ ํ•™์Šต ๋ฐ์ดํ„ฐ์™€ ์œ ์‚ฌํ•œ ์ž…๋ ฅ์— ๋Œ€ํ•ด ์ž˜ ์ž‘๋™ํ•˜๋ฏ€๋กœ, ์‚ฌ์šฉํ•˜๋ ค๋Š” ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ์ถ”๊ฐ€ํ•™์Šต์„ ์‹œํ‚จ๋‹ค. 

 โ†ช๏ธŽ  2) ๊ต์ฐจ์ธ์ฝ”๋” ์ถ”๊ฐ€ ํ•™์Šต :  ๊ฒ€์ƒ‰๋œ ๋ชจ๋“  ๋ฌธ์„œ๊ฐ€ ์•„๋‹ˆ๋ผ, ์ƒ์œ„ ๋ช‡ ๊ฐœ์˜ ์ž…๋ ฅ๋งŒ ํ”„๋กฌํ”„ํŠธ์— ์ถ”๊ฐ€ํ•ด ๊ฒ€์ƒ‰ ์ฆ๊ฐ• ์ƒ์„ฑ์ด ํšจ๊ณผ์ ์œผ๋กœ ์ž‘๋™ํ•˜๋„๋ก ํ•จ 

 

 

 

 


2.   ์–ธ์–ด๋ชจ๋ธ์„ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ๋กœ ๋งŒ๋“ค๊ธฐ 


 

 

โ†ช๏ธŽ  ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์€ 2๊ฐœ์˜ ์ธต์œผ๋กœ ๋‚˜๋‰œ๋‹ค. ์ฒซ ๋ฒˆ์งธ ์ธต์€ ๋Œ€๋Ÿ‰์˜ ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ ์‚ฌ์ „ํ•™์Šตํ•œ BERT๋‚˜ RoBERTa ๊ฐ™์€ ์–ธ์–ด๋ชจ๋ธ์ด๋‹ค. ๋‘๋ฒˆ์งธ ์ธต์€ ํ’€๋ง์ธต์œผ๋กœ ์ž…๋ ฅ ๋ฌธ์žฅ์˜ ๊ธธ์ด์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์งˆ ์ˆ˜ ์žˆ๋Š” ์ถœ๋ ฅ ์ฐจ์›์„ ๊ณ ์ •๋œ ์ฐจ์›์œผ๋กœ ๋งž์ถ”๋Š” ์—ญํ• ์„ ํ•œ๋‹ค. ํ’€๋ง์˜ ๋ฐฉ์‹์—๋Š” ํด๋ž˜์Šค๋ชจ๋“œ, ํ‰๊ท ๋ชจ๋“œ, ์ตœ๋Œ€๋ชจ๋“œ๊ฐ€ ์žˆ๋Š”๋ฐ ์ผ๋ฐ˜์ ์œผ๋กœ ํ‰๊ท  ๋ชจ๋“œ๋ฅผ ๋งŽ์ด ์‚ฌ์šฉํ•œ๋‹ค. ์ด์ฒ˜๋Ÿผ ์‚ฌ์ „ํ•™์Šต๋œ ์–ธ์–ด๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ค๊ณ  ๊ทธ ์œ„์— ํ’€๋ง ์ธต์„ ์ถ”๊ฐ€ํ•˜๊ณ  ๋ฌธ์žฅ์˜ ์˜๋ฏธ๋ฅผ ์ž˜ ๋‹ด์„ ์ˆ˜ ์žˆ๋„๋ก ํ•™์Šตํ•ด์•ผ ํ•œ๋‹ค. 

↪  The Sentence-Transformers library makes it easy to work with sentence embedding models. 
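The mean-mode pooling described above is essentially a mask-aware average of token vectors. A minimal numpy sketch (the library's Pooling module additionally handles batching and the other modes):

```python
import numpy as np

def mean_pooling(token_embeddings, attention_mask):
    # token_embeddings: (seq_len, hidden_dim) token vectors from the language model
    # attention_mask:   (seq_len,) 1 for real tokens, 0 for padding
    mask = attention_mask[:, None].astype(float)   # broadcast over hidden_dim
    summed = (token_embeddings * mask).sum(axis=0)
    return summed / mask.sum()                     # average over real tokens only

# A padded 4-token sequence with 2 real tokens, hidden_dim = 3 (toy values).
tokens = np.array([[1.0, 2.0, 3.0],
                   [3.0, 2.0, 1.0],
                   [9.0, 9.0, 9.0],    # padding rows are ignored by the mask
                   [9.0, 9.0, 9.0]])
mask = np.array([1, 1, 0, 0])
sentence_vec = mean_pooling(tokens, mask)   # → [2.0, 2.0, 2.0]
```

Whatever the sentence length, the result always has hidden_dim dimensions, which is exactly why the pooling layer is needed.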

 

 

2.1  Contrastive Learning 

 

โ  Contrastive learning

 โ†ช๏ธŽ  ๊ด€๋ จ์ด ์žˆ๊ฑฐ๋‚˜ ์œ ์‚ฌํ•œ ๋ฐ์ดํ„ฐ๋Š” ๋” ๊ฐ€๊นŒ์›Œ์ง€๋„๋ก ๋งŒ๋“ค๊ณ  ๊ด€๋ จ์ด ์—†๊ฑฐ๋‚˜ ์œ ์‚ฌํ•˜์ง€ ์•Š์€ ๋ฐ์ดํ„ฐ๋Š” ๋” ๋ฉ€์–ด์ง€๋„๋ก ํ•˜๋Š” ํ•™์Šต ๋ฐฉ์‹

 โ†ช๏ธŽ  ๋Œ€์กฐํ•™์Šต ๋ฐฉ์‹์œผ๋กœ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ฌ ๋•Œ, ๋‹ค์–‘ํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค. 2๊ฐœ์˜ ๋ฌธ์žฅ์„ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์— ๊ฐ๊ฐ ์ž…๋ ฅํ•˜๊ณ  ์„œ๋กœ ์œ ์‚ฌํ•œ ๋ฐ์ดํ„ฐ์ธ ๊ฒฝ์šฐ๋Š” ๊ฐ€๊น๊ฒŒ ๊ทธ๋ ‡์ง€ ์•Š์€ ๊ฒฝ์šฐ๋Š” ๋ฉ€๊ฒŒ ๋งŒ๋“ค ์ˆ˜ ์žˆ๋‹ค. ๋˜๋Š” ์„œ๋กœ ์ด์–ด์ง€๋Š” ๋ฌธ์žฅ์ด๋ผ๋ฉด ๊ฐ€๊น๊ฒŒ ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด ์„œ๋กœ ๋ฉ€๊ฒŒ ๋งŒ๋“ค ์ˆ˜ ์žˆ๋‹ค. ๋งˆ์ง€๋ง‰์œผ๋กœ ์„œ๋กœ ์งˆ๋ฌธ๋‹ต๋ณ€ ๊ด€๊ณ„์ธ ๊ฒฝ์šฐ๋ผ๋ฉด ๊ฐ€๊น๊ฒŒ ์•„๋‹ˆ๋ฉด ๋ฉ€๊ฒŒ ํ•™์Šต์‹œํ‚ฌ ์ˆ˜ ์žˆ๋‹ค. 

 

 

 

2.2  Hands-on [1] 

 

โ  ์–ธ์–ด๋ชจ๋ธ์„ ๊ทธ๋Œ€๋กœ ๋ถˆ๋Ÿฌ์™€ ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ์„ ๋งŒ๋“ค์–ด๋ณด๊ธฐ 

 

1) ์‚ฌ์ „ ํ•™์Šต๋œ ์–ธ์–ด๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์™€ ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋งŒ๋“ค๊ธฐ 

 

from sentence_transformers import SentenceTransformer, models

# 1) Load the pre-trained model through the models module
transformer_model = models.Transformer('klue/roberta-base')

# 2) Create a mean pooling layer
pooling_layer = models.Pooling(
    transformer_model.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

# 3) Combine the two modules into a sentence embedding model with the SentenceTransformer class
embedding_model = SentenceTransformer(modules=[transformer_model, pooling_layer])

 

 

 

 

2) ์‹ค์Šต ๋ฐ์ดํ„ฐ์…‹ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

 โ†ช๏ธŽ  KLUE์˜ STS ๋ฐ์ดํ„ฐ์…‹ (๋‘ ๋ฌธ์žฅ์ด ์„œ๋กœ ์–ผ๋งˆ๋‚˜ ์œ ์‚ฌํ•œ์ง€ ์ ์ˆ˜๋ฅผ ๋งค๊ธด ๋ฐ์ดํ„ฐ์…‹) 

 โ†ช๏ธŽ  labels ์นผ๋Ÿผ์— ๋‘ ๋ฌธ์žฅ์ด ์–ผ๋งˆ๋‚˜ ์œ ์‚ฌํ•œ์ง€๋ฅผ ๋‚˜ํƒ€๋‚ด๋Š” ๋‹ค์–‘ํ•œ ํ˜•์‹์˜ ๋ ˆ์ด๋ธ”์ด ์žˆ์Œ (label์„ ์‚ฌ์šฉํ•  ์˜ˆ์ •) 

 

from datasets import load_dataset
klue_sts_train = load_dataset('klue', 'sts', split='train')
klue_sts_test = load_dataset('klue', 'sts', split='validation')
klue_sts_train[0]

# {'guid': 'klue-sts-v1_train_00000',
#  'source': 'airbnb-rtt',
#  'sentence1': '์ˆ™์†Œ ์œ„์น˜๋Š” ์ฐพ๊ธฐ ์‰ฝ๊ณ  ์ผ๋ฐ˜์ ์ธ ํ•œ๊ตญ์˜ ๋ฐ˜์ง€ํ•˜ ์ˆ™์†Œ์ž…๋‹ˆ๋‹ค.',
#  'sentence2': '์ˆ™๋ฐ•์‹œ์„ค์˜ ์œ„์น˜๋Š” ์‰ฝ๊ฒŒ ์ฐพ์„ ์ˆ˜ ์žˆ๊ณ  ํ•œ๊ตญ์˜ ๋Œ€ํ‘œ์ ์ธ ๋ฐ˜์ง€ํ•˜ ์ˆ™๋ฐ•์‹œ์„ค์ž…๋‹ˆ๋‹ค.',
#  'labels': {'label': 3.7, 'real-label': 3.714285714285714, 'binary-label': 1}}


# Use 10% of the training set as a validation set
klue_sts_train = klue_sts_train.train_test_split(test_size=0.1, seed=42)
klue_sts_train, klue_sts_eval = klue_sts_train['train'], klue_sts_train['test']

 

 

 

3) Normalize the labels 

 

from sentence_transformers import InputExample


# Normalize similarity scores to the 0-1 range and wrap them in InputExample objects
def prepare_sts_examples(dataset):
    examples = []
    for data in dataset:
        examples.append(
            InputExample(
                texts=[data['sentence1'], data['sentence2']],  # sentence pair as a list
                label=data['labels']['label'] / 5.0  # normalize
            )
        )
    return examples


# Run the preprocessing
train_examples = prepare_sts_examples(klue_sts_train)
eval_examples = prepare_sts_examples(klue_sts_eval)
test_examples = prepare_sts_examples(klue_sts_test)

 

 

4) Build the batched dataset 

 

from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

 

 

 

5) ๊ฒ€์ฆ์„ ์œ„ํ•œ ํ‰๊ฐ€ ๊ฐ์ฒด ์ค€๋น„ ๋ฐ ๋ชจ๋ธ ์„ฑ๋Šฅ ๊ฒฐ๊ณผ 

 โ†ช๏ธŽ  EmbeddingSimilarityEvaluator ๋ฅผ ์‚ฌ์šฉํ•ด ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ํ‰๊ฐ€ํ•  ๋•Œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก ์ค€๋น„ 

 โ†ช๏ธŽ  from_input_examples ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•ด, ๊ฒ€์ฆ ๋ฐ์ดํ„ฐ์…‹๊ณผ ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉํ•˜๋Š” ํ‰๊ฐ€ ๊ฐ์ฒด๋ฅผ ์ƒ์„ฑ 

 

from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

eval_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_examples)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_examples)


test_evaluator(embedding_model)
# 0.36460670798564826

 

 โ†ช๏ธŽ  ์–ธ์–ด๋ชจ๋ธ์„ ๊ทธ๋Œ€๋กœ ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ๋กœ ๋งŒ๋“  embedding_model์ด ์–ผ๋งˆ๋‚˜ ๋ฌธ์žฅ์˜ ์˜๋ฏธ๋ฅผ ์ž˜ ๋ฐ˜์˜ํ•ด ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•˜๋Š”์ง€ ํ™•์ธํ•ด๋ณด์•˜์„ ๋•Œ, 0.364๋กœ ์—ญํ• ์„ ์ž˜ํ•˜์ง€ ๋ชปํ•˜๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค. 

 

 

 

 

2.3  Hands-on [2]: Training the Embedding Model on Similar Sentence Pairs 

 

โ  ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ํ•™์Šต

 โ†ช๏ธŽ  ๊ธฐ๋ณธ ์–ธ์–ด๋ชจ๋ธ :  klue/roberta-base

 โ†ช๏ธŽ  ํ•™์Šต์— ์‚ฌ์šฉ๋  ์†์‹คํ•จ์ˆ˜ : CosineSimilarityLoss๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. ํ•™์Šต ๋ฐ์ดํ„ฐ๋ฅผ ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ์œผ๋กœ ๋ณ€ํ™˜ํ•˜๊ณ  ๋‘ ๋ฌธ์žฅ ์‚ฌ์ด์˜ ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„์™€ ์ •๋‹ต ์œ ์‚ฌ๋„๋ฅผ ๋น„๊ตํ•ด ํ•™์Šต์„ ์ˆ˜ํ–‰ํ•œ๋‹ค. 

 

from sentence_transformers import losses

num_epochs = 4
model_name = 'klue/roberta-base'
model_save_path = 'output/training_sts_' + model_name.replace("/", "-")
train_loss = losses.CosineSimilarityLoss(model=embedding_model)

# ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ํ•™์Šต
embedding_model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=eval_evaluator,
    epochs=num_epochs,
    evaluation_steps=1000,
    warmup_steps=100,
    output_path=model_save_path
)

 

 

 

โ  ํ•™์Šตํ•œ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅํ‰๊ฐ€ 

 โ†ช๏ธŽ  ํ•™์Šต๋œ ๋ชจ๋ธ์ด ์ €์žฅ๋œ ๊ฒฝ๋กœ์—์„œ ๋ชจ๋ธ์„ ์ฝ์–ด์˜ค๊ณ , ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ ํ‰๊ฐ€ํ•œ๋‹ค. ์ ์ˆ˜๊ฐ€ 0.364์—์„œ 0.896์œผ๋กœ ํฌ๊ฒŒ ํ–ฅ์ƒ๋œ ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. 

trained_embedding_model = SentenceTransformer(model_save_path)
test_evaluator(trained_embedding_model)
# 0.8965595666246748

 

 

 

โ  ํ—ˆ๊น…ํŽ˜์ด์Šค ํ—ˆ๋ธŒ์— ๋ชจ๋ธ ์ €์žฅ 

 โ†ช๏ธŽ  ๊ณ„์ • ํ† ํฐ์„ ํ†ตํ•ด ํ—ˆ๊น…ํŽ˜์ด์Šค ํ—ˆ๋ธŒ์— ์ ‘๊ทผํ•ด, ๋ชจ๋ธ์„ ์—…๋กœ๋“œํ•œ๋‹ค. 

from huggingface_hub import login
from huggingface_hub import HfApi

login(token='ํ—ˆ๊น…ํŽ˜์ด์Šค ํ—ˆ๋ธŒ ํ† ํฐ ์ž…๋ ฅ')
api = HfApi()
repo_id="klue-roberta-base-klue-sts"
api.create_repo(repo_id=repo_id)


# ๋ชจ๋ธ ์—…๋กœ๋“œ 
api.upload_folder(
    folder_path=model_save_path,
    repo_id=f"๋ณธ์ธ์˜ ํ—ˆ๊น…ํŽ˜์ด์Šค ์•„์ด๋”” ์ž…๋ ฅ/{repo_id}",
    repo_type="model",
)

 

 

 

 


3.   ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๊ธฐ 


 

3.1  Preparing for Training

 

โ [๋ณต์Šต] RAG

 โ†ช๏ธŽ  RAG๋Š” ๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ์™€ ๊ด€๋ จ๋œ ๋ฌธ์„œ๋ฅผ ์ฐพ์•„, LLM ํ”„๋กฌํ”„ํŠธ์— ๋งฅ๋ฝ ๋ฐ์ดํ„ฐ๋กœ ์ถ”๊ฐ€ํ•  ๋•Œ, ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ ํ™œ์šฉํ•œ๋‹ค. ์ข‹์€ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์ด๋ผ๋ฉด ๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ์™€ ๊ด€๋ จ์žˆ๋Š” ๋ฌธ์„œ๋Š” ์œ ์‚ฌ๋„๊ฐ€ 1์— ๊ฐ€๊น๊ฒŒ ๋‚˜์™€์•ผ ํ•œ๋‹ค. 

 โ†ช๏ธŽ  ์ž„๋ฒ ๋”ฉ๋ชจ๋ธ์„ KLUE์˜ MRC ๋ฐ์ดํ„ฐ์…‹(๊ธฐ์‚ฌ ๋ณธ๋ฌธ ๋ฐ ํ•ด๋‹น ๊ธฐ์‚ฌ์™€ ๊ด€๋ จ๋œ ์งˆ๋ฌธ์„ ์ˆ˜์ง‘ํ•œ ๋ฐ์ดํ„ฐ)์œผ๋กœ ์ถ”๊ฐ€ํ•™์Šต์‹œ์ผœ ์‹ค์Šต ๋ฐ์ดํ„ฐ์˜ ๋ฌธ์žฅ ์‚ฌ์ด์˜ ์œ ์‚ฌ๋„๋ฅผ ๋” ์ž˜ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ๋„๋ก ๋งŒ๋“ ๋‹ค. 

 

 

โ  ๋ฏธ์„ธ์กฐ์ • 

 โ†ช๏ธŽ  ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ๋„, ๋‹ค๋ฅธ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ๊ณผ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ ํ•™์Šต ๋ฐ์ดํ„ฐ์™€ ์œ ์‚ฌํ•œ ๋ฐ์ดํ„ฐ์ผ ๋•Œ ๊ฐ€์žฅ ์ž˜ ๋™์ž‘ํ•œ๋‹ค. 

 โ†ช๏ธŽ  ๋”ฐ๋ผ์„œ, ์‚ฌ์ „ํ•™์Šต๋œ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ ๊ทธ๋Œ€๋กœ ํ™œ์šฉํ•˜๋Š” ๊ฒฝ์šฐ, ์‚ฌ์ „ ํ•™์Šต์— ์‚ฌ์šฉ๋œ ๋ฐ์ดํ„ฐ์…‹์ด, ์‹ค์Šต์— ์‚ฌ์šฉํ•˜๋Š” MRC ๋ฐ์ดํ„ฐ์…‹๊ณผ ๋‹จ์–ด, ์ฃผ์ œ ๋“ฑ์ด ๋‹ค๋ฅด๋ฉด ์„ฑ๋Šฅ์ด ๋‚ฎ์•„์ง„๋‹ค. ๋”ฐ๋ผ์„œ MRC ๋ฐ์ดํ„ฐ์…‹์— ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ ํ™œ์šฉํ•˜๋ ค๊ณ  ํ•œ๋‹ค๋ฉด, ๊ทธ ๋ชฉ์ ์— ๋งž๊ฒŒ MRC ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ๋ฏธ์„ธ์กฐ์ • ํ•ด์•ผ ํ•œ๋‹ค. 

 

 

โ  ์‹ค์Šต  

 

1) 2์žฅ์—์„œ ์ €์žฅํ•œ ๊ธฐ๋ณธ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ 

 

from sentence_transformers import SentenceTransformer
sentence_model = SentenceTransformer('shangrilar/klue-roberta-base-klue-sts')

 

 

2) Add an article unrelated to each question 

 ↪  The question and context columns are related, so they get label 1; the irrelevant_context column, built by pairing each question with a randomly sampled unrelated article, gets label 0

 

def add_ir_context(df):
  irrelevant_contexts = []
  for idx, row in df.iterrows():
    title = row['title']
    # Sample one context from an article with a different title
    irrelevant_contexts.append(
        df[df['title'] != title].sample(n=1)['context'].values[0])
  df['irrelevant_context'] = irrelevant_contexts
  return df

df_train_ir = add_ir_context(df_train)

 

 

 

3) Evaluate the base embedding model before any fine-tuning 



from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

# examples: InputExample pairs built from the MRC evaluation data (prepared earlier)
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    examples
)
evaluator(sentence_model)
# 0.8151553052035344

 

 

 

 

3.2  Fine-tuning 

 

โ  MNR์†์‹ค์„ ํ™œ์šฉํ•ด ๋ฏธ์„ธ์กฐ์ •ํ•˜๊ธฐ 

 โ†ช๏ธŽ  ๊ธฐ๋ณธ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ ๋งŒ๋“ค ๋•Œ ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„ ์†์‹ค์„ ํ™œ์šฉํ•ด ๋ชจ๋ธ์„ ํ•™์Šต์‹œ์ผฐ์—ˆ๋Š”๋ฐ, ์ด๋ฒˆ์—๋Š” Multiple Negatives Ranking ์†์‹ค์„ ์‚ฌ์šฉํ•ด ๋ฏธ์„ธ์กฐ์ •์„ ์‹œ๋„ํ•œ๋‹ค. 

 โ†ช๏ธŽ  MNR์€ MRC ๋ฐ์ดํ„ฐ์…‹๊ณผ ๊ฐ™์ด ์„œ๋กœ ๊ด€๋ จ์ด ์žˆ๋Š” ๋ฌธ์žฅ๋งŒ ์žˆ๋Š” ๊ฒฝ์šฐ ์‚ฌ์šฉํ•˜๊ธฐ ์ข‹์€ ์†์‹คํ•จ์ˆ˜์ด๋‹ค. MNR ์†์‹ค์„ ์‚ฌ์šฉํ•˜๋ฉด, ์ž๋™์œผ๋กœ ํ•˜๋‚˜์˜ ๋ฐฐ์น˜ ๋ฐ์ดํ„ฐ ์•ˆ์—์„œ ๋‹ค๋ฅธ ๋ฐ์ดํ„ฐ์˜ ๊ธฐ์‚ฌ ๋ณธ๋ฌธ์„ ๊ด€๋ จ์—†๋Š” ๋ฐ์ดํ„ฐ๋กœ ์‚ฌ์šฉํ•ด ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ค๊ธฐ ๋•Œ๋ฌธ์—, ์„œ๋กœ ๊ด€๋ จ์ด ์žˆ๋Š” ๋ฐ์ดํ„ฐ๋งŒ์œผ๋กœ ํ•™์Šต ๋ฐ์ดํ„ฐ๋ฅผ ๊ตฌ์„ฑํ•˜๋ฉด ๋œ๋‹ค. 

 

1) ๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ 

# ๊ธ์ • ๋ฐ์ดํ„ฐ๋งŒ์œผ๋กœ ํ•™์Šต ๋ฐ์ดํ„ฐ ๊ตฌ์„ฑ 
train_samples = []
for idx, row in df_train_ir.iterrows():
    train_samples.append(InputExample(
        texts=[row['question'], row['context']]
    ))
    
    
 # ์ค‘๋ณต ํ•™์Šต ๋ฐ์ดํ„ฐ ์ œ๊ฑฐ 
from sentence_transformers import datasets

batch_size = 16

loader = datasets.NoDuplicatesDataLoader(
    train_samples, batch_size=batch_size)

 

 

2) MNR ์†์‹คํ•จ์ˆ˜ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ๋ฐ ๋ฏธ์„ธ์กฐ์ • 

 

from sentence_transformers import losses
# [Note] sentence_model = SentenceTransformer('shangrilar/klue-roberta-base-klue-sts')

# Load the loss function
loss = losses.MultipleNegativesRankingLoss(sentence_model)

# Run the fine-tuning
epochs = 1
save_path = './klue_mrc_mnr'

sentence_model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=100,
    output_path=save_path,
    show_progress_bar=True
)

 

 

3) Evaluation 

 ↪  Performance rises from 0.815 before fine-tuning to 0.86 

evaluator(sentence_model)
# 0.8600968992433692

 

 

 

4) ๋ชจ๋ธ ์—…๋กœ๋“œ

from huggingface_hub import HfApi
api = HfApi()
repo_id = "klue-roberta-base-klue-sts-mrc"
api.create_repo(repo_id=repo_id)

api.upload_folder(
    folder_path=save_path,
    repo_id=f"๋ณธ์ธ์˜ ์•„์ด๋”” ์ž…๋ ฅ/{repo_id}",
    repo_type="model",
)

 

 

 

 

 

4.  ๊ฒ€์ƒ‰ ํ’ˆ์งˆ์„ ๋†’์ด๋Š” ์ˆœ์œ„ ์žฌ์ •๋ ฌ 


 

โ  ๊ต์ฐจ์ธ์ฝ”๋” ๋ฏธ์„ธ์กฐ์ •

 โ†ช๏ธŽ  ๊ต์ฐจ์ธ์ฝ”๋”๋Š” 2๊ฐœ ๋ฌธ์žฅ์„ ์ž…๋ ฅ๋ฐ›์•„ ๋ฌธ์žฅ ์‚ฌ์ด์˜ ๊ด€๊ณ„๋ฅผ ํ•™์Šตํ•˜๋ฏ€๋กœ, ๋ฌธ์žฅ๋ถ„๋ฅ˜๋ชจ๋ธ์„ ์‚ฌ์šฉํ•œ๋‹ค. 

 โ†ช๏ธŽ  ๋ฌธ์žฅ๋ถ„๋ฅ˜ ๋ชจ๋ธ์ด๋ผ, transformers ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋กœ ๋ชจ๋ธ์„ ์ง์ ‘ ํ•™์Šตํ•˜๋Š” ๋ฐฉ์‹๋„ ๊ฐ€๋Šฅํ•˜์ง€๋งŒ, ์‹ค์Šต์—์„œ๋Š” CrossEncoder์™€ ๋ฏธ์„ธ์กฐ์ • ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. 

 

 

1) Load a CrossEncoder 

from sentence_transformers.cross_encoder import CrossEncoder
cross_model = CrossEncoder('klue/roberta-small', num_labels=1)

 

 โ†ช๏ธŽ  ๊ต์ฐจ์ธ์ฝ”๋”๋Š” ๋งŽ์€ ๊ณ„์‚ฐ์„ ํ•ด์•ผ ํ•˜๋ฏ€๋กœ ํŒŒ๋ผ๋ฏธํ„ฐ์ˆ˜๊ฐ€ ์ž‘์€ roberta-small ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•œ๋‹ค. 

 โ†ช๏ธŽ  roberta-small ๋ชจ๋ธ์€ ๋ถ„๋ฅ˜ํ—ค๋“œ๊ฐ€ ์—†๋Š” ์–ธ์–ด๋ชจ๋ธ์ด๋ฏ€๋กœ, ๊ต์ฐจ์ธ์ฝ”๋”๋กœ ๋ถˆ๋Ÿฌ์˜ค๋ ค๋ฉด ๋ถ„๋ฅ˜ ํ—ค๋“œ๋Š” ๋žœ๋ค์œผ๋กœ ์ดˆ๊ธฐํ™”๋œ๋‹ค(์„ฑ๋Šฅ์ด ๋‚ฎ์„์ˆ˜๋ฐ–์— ์—†์Œ). 

 

 

2) Initial performance

from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
ce_evaluator = CECorrelationEvaluator.from_input_examples(examples)
ce_evaluator(cross_model)
# 0.003316821814673943

 

 โ†ช๏ธŽ  ๋ฏธ์„ธ์กฐ์ •ํ•˜์ง€ ์•Š์€ ๊ต์ฐจ ์ธ์ฝ”๋”์˜ ์„ฑ๋Šฅ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด, ๋ฌธ์žฅ์˜ ๊ด€๋ จ์„ฑ์„ ์ž˜ ๊ณ„์‚ฐํ•˜์ง€ ๋ชปํ•˜๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค. 

 

 

3) ๊ต์ฐจ์ธ์ฝ”๋” ํ•™์Šต ์ˆ˜ํ–‰ 

# ํ•™์Šต ๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„ 

train_samples = []
for idx, row in df_train_ir.iterrows():
    train_samples.append(InputExample(
        texts=[row['question'], row['context']], label=1
    ))
    train_samples.append(InputExample(
        texts=[row['question'], row['irrelevant_context']], label=0
    ))

 

train_batch_size = 16
num_epochs = 1
model_save_path = 'output/training_mrc'

train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)

cross_model.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    warmup_steps=100,
    output_path=model_save_path
)


# ๊ฒฐ๊ณผ ํ‰๊ฐ€ 
ce_evaluator(cross_model)
# 0.8650250798639563

 

 ↪  Checking performance again after training, the score has risen sharply to 0.865. 

 

 

 

 

5.  ๋ฐ”์ด์ธ์ฝ”๋”์™€ ๊ต์ฐจ ์ธ์ฝ”๋”๋กœ ๊ฐœ์„ ๋œ RAG๊ตฌํ˜„ํ•˜๊ธฐ 


 

โ  3๊ฐœ ๋ชจ๋ธ ํ•™์Šต

 โ†ช๏ธŽ  1) ์–ธ์–ด๋ชจ๋ธ์„ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ (์‚ฌ์ „ํ•™์Šต๋œ ์–ธ์–ด๋ชจ๋ธ+Pooling์ธต) ๋กœ ๋ณ€ํ™˜ํ•œ ๊ธฐ๋ณธ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ

 โ†ช๏ธŽ  2) ๊ธฐ๋ณธ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ MRC ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ๋ฏธ์„ธ์กฐ์ •ํ•œ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ 

 โ†ช๏ธŽ  3) MRC ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํ•™์Šต์‹œํ‚จ ๊ต์ฐจ์ธ์ฝ”๋” 

 

โ  ์„ฑ๋Šฅ์ง€ํ‘œ

 โ†ช๏ธŽ  HitRate@10 : ์งˆ๋ฌธ ์นผ๋Ÿผ์„ ์ž…๋ ฅํ–ˆ์„ ๋•Œ, ๊ฒ€์ƒ‰๋œ ์ƒ์œ„ 10๊ฐœ ๊ธฐ์‚ฌ ๋ณธ๋ฌธ์— ์ •๋‹ต์ด ์žˆ๋Š” ๋น„์œจ 

 โ†ช๏ธŽ  ๊ด€๋ จํ•ด evaluate_hit_rate ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•จ 
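The evaluate_hit_rate helper itself is not shown here, but the metric is easy to sketch. A minimal version over lists of ranked document ids (a simplification of the actual helper, which also times the search):

```python
def hit_rate_at_k(retrieved, answers, k=10):
    # retrieved[i]: ranked document ids returned for question i
    # answers[i]:   id of the gold context for question i
    hits = sum(1 for preds, gold in zip(retrieved, answers) if gold in preds[:k])
    return hits / len(answers)

# Toy check: the gold id is in the top-2 list for two of the three questions.
score = hit_rate_at_k([[1, 2], [3, 4], [5, 6]], [2, 9, 5], k=2)  # → 2/3
```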

 

 

5.1  Retrieval with the Base Embedding Model 

 

from sentence_transformers import SentenceTransformer
base_embedding_model = SentenceTransformer('shangrilar/klue-roberta-base-klue-sts')
base_index = make_embedding_index(base_embedding_model, klue_mrc_test['context'])

evaluate_hit_rate(klue_mrc_test, base_embedding_model, base_index, 10)
# (0.88, 13.216430425643921)

 

 โ†ช๏ธŽ  88%์˜ ๋ฐ์ดํ„ฐ์—์„œ ์ •๋‹ต์„ ์ž˜ ์ฐพ์•˜๊ณ , ํ‰๊ฐ€์—๋Š” 13์ดˆ๊ฐ€ ๊ฑธ๋ ธ๋‹ค. ์ด 1000๊ฐœ์˜ ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ–ˆ์œผ๋ฏ€๋กœ ๋ฐ์ดํ„ฐ ํ•˜๋‚˜๋‹น 0.013์ดˆ๊ฐ€ ์†Œ์š”๋จ 

 

 

 

5.2  Retrieval with the Fine-tuned Embedding Model 

 

finetuned_embedding_model = SentenceTransformer('shangrilar/klue-roberta-base-klue-sts-mrc')
finetuned_index = make_embedding_index(finetuned_embedding_model, klue_mrc_test['context'])
evaluate_hit_rate(klue_mrc_test, finetuned_embedding_model, finetuned_index, 10)
# (0.946, 14.309881687164307)

 

 โ†ช๏ธŽ  94.6%์˜ ๋ฐ์ดํ„ฐ์—์„œ ์ •๋‹ต ๊ธฐ์‚ฌ๋ฅผ ์ •ํ™•ํžˆ ๊ฐ€์ ธ์™”๋‹ค. ์•ฝ๊ฐ„์˜ ๋ฏธ์„ธ์กฐ์ •๋งŒ์œผ๋กœ๋„ 7.5%์ •๋„ ์„ฑ๋Šฅ์ด ํ–ฅ์ƒ๋˜์—ˆ๋‹ค. 

 

 

 

 

5.3  Combining the Fine-tuned Embedding Model with the Cross-Encoder

 

โ  ๋ฐ”์ด์ธ์ฝ”๋” + ๊ต์ฐจ์ธ์ฝ”๋”

 โ†ช๏ธŽ  ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์˜ ์ƒ์œ„ N ๊ฐœ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ›์•„, ๊ต์ฐจ์ธ์ฝ”๋”๊ฐ€ ์ˆœ์œ„๋ฅผ ์œ ์‚ฌ๋„์ˆœ์œผ๋กœ ์žฌ์ •๋ ฌํ•œ ํ›„ ์ƒ์œ„ K๊ฐœ๋ฅผ ์ถ”์ถœํ•˜๋ฉด, ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ์„ ํ†ตํ•ด ์œ ์‚ฌ๋„๊ฐ€ ๋†’์€ ์ƒ์œ„ K๊ฐœ๋ฅผ ๋ฐ”๋กœ ๋ฝ‘์•˜์„ ๋•Œ๋ณด๋‹ค ์„ฑ๋Šฅ์„ ๋†’์ผ ์ˆ˜ ์žˆ๋‹ค. 

 โ†ช๏ธŽ  ๊ต์ฐจ์ธ์ฝ”๋”์˜ ๊ฒฝ์šฐ, ์†๋„๊ฐ€ ๋А๋ฆฌ๋ฏ€๋กœ ์ „์ฒด ๋ฌธ์„œ๋ฅผ ๊ฒ€์ƒ‰ ๋Œ€์ƒ์œผ๋กœ ํ•˜์ง€ ์•Š๊ณ  ์ƒ์œ„ N๊ฐœ๋งŒ์„ ๋Œ€์ƒ์œผ๋กœ ๊ณ„์‚ฐํ•˜๋„๋ก ๋ฒ”์œ„๋ฅผ ์ขํžŒ๋‹ค. 

 

 

import time
import numpy as np
from tqdm.auto import tqdm

# Evaluation function that includes the reranking step
def evaluate_hit_rate_with_rerank(datasets, embedding_model, cross_model, index, bi_k=30, cross_k=10):
  start_time = time.time()
  predictions = []
  for question_idx, question in enumerate(tqdm(datasets['question'])):
    indices = find_embedding_top_k(question, embedding_model, index, bi_k)[0]
    predictions.append(rerank_top_k(cross_model, question_idx, indices, k=cross_k))
  total_prediction_count = len(predictions)
  hit_count = 0
  contexts = datasets['context']
  for idx, prediction in enumerate(predictions):
    for pred in prediction:
      if contexts[pred] == contexts[idx]:
        hit_count += 1
        break
  end_time = time.time()
  return hit_count / total_prediction_count, end_time - start_time, predictions


# Results
hit_rate, consumed_time, predictions = evaluate_hit_rate_with_rerank(klue_mrc_test, finetuned_embedding_model, cross_model, finetuned_index, bi_k=30, cross_k=10)
hit_rate, consumed_time
# (0.973, 1103.055629491806)

 

 

 ↪  The correct answer was found for 97.3% of the data, though with reranking the evaluation took about 1,103 seconds. 

 

