1๏ธโƒฃ AI•DS/๐ŸŒ LLM

[Book Study] 6. Training an sLLM

by isdawell 2025. 8. 7.

 

๐Ÿ“ Text2SQL

1. ๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์ถ•
2. ๋ชจ๋ธ ์„ฑ๋Šฅํ‰๊ฐ€์— ์‚ฌ์šฉํ•  ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์ถ• 

 

 

 

โ–บ  sLLM : an LLM specialized for a specific task or domain

 

Text2SQL

 

 

GitHub Copilot : https://spartacodingclub.kr/blog/github_copoilot

 

 

 

 

 

 

1.   Text2SQL Datasets


 

1.1   Representative Text2SQL datasets

 

โ  Text2SQL ๋ฐ์ดํ„ฐ์…‹ 

 โ†ช๏ธŽ  ํ…Œ์ด๋ธ” ๋ฐ ์นผ๋Ÿผ ์ •๋ณด (๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์ •๋ณด)

 โ†ช๏ธŽ  ์š”์ฒญ์‚ฌํ•ญ (์–ด๋–ค ๋ฐ์ดํ„ฐ๋ฅผ ์ถ”์ถœํ•˜๊ณ  ์‹ถ์€์ง€)

 

โ  WikiSQL 

 โ†ช๏ธŽ  [ํ…Œ์ด๋ธ”๋ช…, ์นผ๋Ÿผ๋ช…, ์นผ๋Ÿผํ˜•์‹, ์š”์ฒญ์‚ฌํ•ญ(์ž์—ฐ์–ด), ์ •๋‹ตSQL์ฟผ๋ฆฌ๋ฌธ] ์œผ๋กœ ๊ตฌ์„ฑ๋œ ๋ฐ์ดํ„ฐ์…‹, ๋น„๊ต์  ๊ฐ„๋‹จํ•œ ์ฟผ๋ฆฌ๋ฌธ ์˜ˆ์‹œ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Œ

 

โ  Spider

 โ†ช๏ธŽ  ๋ณต์žกํ•œ SQL๋ฌธ๋„ ํฌํ•จ๋œ ๋ฐ์ดํ„ฐ์…‹ 

 

 

 

1.2   Korean datasets

 

โ  NL2SQL

 โ†ช๏ธŽ  ๋น…์ฟผ๋ฆฌ์™€ ํ•˜์ดํผํด๋กœ๋ฐ”X๋กœ ๊ตฌํ˜„ํ•˜๋Š” ์˜ˆ์‹œ

 

 

 

 

1.3  Using synthetic data

 

โ  ์‹ค์Šต์„ ์œ„ํ•ด ์ƒ์„ฑํ•œ ์˜ˆ์‹œ ๋ฐ์ดํ„ฐ

 โ†ช๏ธŽ  db_id : ๋™์ผํ•œ id๋ฅผ ๊ฐ–๋Š” ํ…Œ์ด๋ธ”์€ ๊ฐ™์€ ๋„๋ฉ”์ธ (ex.๊ฒŒ์ž„ - ๊ฒŒ์ž„์ธ ์ƒํ™ฉ์„ ๊ฐ€์ •ํ•˜๊ณ  ์ƒ์„ฑํ•œ ํ…Œ์ด๋ธ”) ์„ ๊ณต์œ ํ•จ

 โ†ช๏ธŽ  context : SQL์ƒ์„ฑ์— ์‚ฌ์šฉํ•  ํ…Œ์ด๋ธ” ์ •๋ณด 

 โ†ช๏ธŽ  question : ๋ฐ์ดํ„ฐ ์š”์ฒญ์‚ฌํ•ญ 

 โ†ช๏ธŽ  answer : ์š”์ฒญ์— ๋Œ€ํ•œ SQL ์ •๋‹ต 
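
For illustration, a single record in this synthetic dataset might look like the following. The field names come from the list above; the values here are made up for this note, and the real dataset uses Korean questions.

example_record = {
    "db_id": 1,  # hypothetical id for the game domain
    "context": "CREATE TABLE players (player_id INT PRIMARY KEY, username VARCHAR(255), date_joined DATETIME);",
    "question": "How many players joined in 2024?",
    "answer": "SELECT COUNT(*) FROM players WHERE YEAR(date_joined) = 2024;"
}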

 

 

 

 


2. Preparing the Evaluation Pipeline


 

๐Ÿ‘€  Instead of the evaluation methods commonly used for Text2SQL, this exercise uses GPT-4 to judge whether the generated SQL is correct.

 

 

2.1   Text2SQL evaluation methods

 

โ  EM๋ฐฉ์‹

 โ†ช๏ธŽ  ์ƒ์„ฑํ•œ SQL์ด ๋ฌธ์ž์—ด ๊ทธ๋Œ€๋กœ ๋™์ผํ•œ์ง€ ํ™•์ธํ•˜๋Š” ๋ฐฉ์‹

 โ†ช๏ธŽ  ์˜๋ฏธ์ƒ์œผ๋กœ ๋™์ผํ•œ SQL์ฟผ๋ฆฌ๊ฐ€ ๋‹ค์–‘ํ•˜๊ฒŒ ๋‚˜์˜ฌ ์ˆ˜ ์žˆ๋Š”๋ฐ ๋‹จ์ˆœํžˆ ๋ฌธ์ž์—ด์ด ๋‹ค๋ฅด๋ฉด ๋‹ค๋ฅด๋‹ค๊ณ  ํŒ๋‹จํ•˜๋Š” ๋ฌธ์ œ ๋ฐœ์ƒ

 

โ  EX๋ฐฉ์‹

 โ†ช๏ธŽ  ์‹คํ–‰ ์ •ํ™•๋„ ๋ฐฉ์‹์œผ๋กœ, ์ฟผ๋ฆฌ๋ฅผ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ๋Š” DB๋ฅผ ๋งŒ๋“ค๊ณ , ํ”„๋กœ๊ทธ๋ž˜๋ฐ ๋ฐฉ์‹์œผ๋กœ SQL์ฟผ๋ฆฌ๋ฅผ ์ˆ˜ํ–‰ํ•ด ์ •๋‹ต๊ณผ ์ผ์น˜ํ•˜๋Š”์ง€ ํ™•์ธํ•˜๋Š” ๋ฐฉ์‹ 

 โ†ช๏ธŽ  DB๋ฅผ ์ถ”๊ฐ€๋กœ ์ค€๋น„ํ•ด์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ฒˆ๊ฑฐ๋กœ์›€ 
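
A minimal sketch of execution-based comparison using an in-memory SQLite DB (the table and queries are made up; this code is not from the book):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE players (player_id INTEGER PRIMARY KEY, username TEXT)")
conn.executemany("INSERT INTO players (username) VALUES (?)",
                 [("admin_kim",), ("player1",), ("admin_lee",)])

gt_sql  = "SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';"
gen_sql = "select count(*) from players where username like '%admin%';"

# A generated query counts as correct if its execution result matches the reference result.
is_correct = conn.execute(gen_sql).fetchall() == conn.execute(gt_sql).fetchall()
print(is_correct)  # True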

 

โ  LLM์„ LLM์œผ๋กœ ํ‰๊ฐ€ํ•˜์ž

 โ†ช๏ธŽ  ์ตœ๊ทผ์—๋Š” LLM์„ ํ™œ์šฉํ•ด LLM์˜ ์ƒ์„ฑ ๊ฒฐ๊ณผ๋ฅผ ํ‰๊ฐ€ํ•˜๋Š” ๋ฐฉ์‹์ด ํ™œ๋ฐœํžˆ ์—ฐ๊ตฌ๋˜๊ณ  ์žˆ๋‹ค. ์ด๋ฒˆ ์‹ค์Šต์—์„œ๋„ LLM์ด ์ƒ์„ฑํ•œ SQL์ด ๋ฐ์ดํ„ฐ ์š”์ฒญ์„ ์ž˜ ํ•ด๊ฒฐํ•˜๋Š”์ง€ GPT-4๋ฅผ ์‚ฌ์šฉํ•ด ํ™•์ธํ•œ๋‹ค. 

 

 

 

2.2   Building the evaluation dataset

 

โ  8๊ฐœ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค(๋„๋ฉ”์ธ)์— ๋Œ€ํ•ด ํ•ฉ์„ฑ ๋ฐ์ดํ„ฐ์…‹์„ ์ƒ์„ฑ

 โ†ช๏ธŽ  7๊ฐœ๋Š” ํ•™์Šต์— ์‚ฌ์šฉ, 1๊ฐœ (๊ฒŒ์ž„๋„๋ฉ”์ธ)์€ ํ‰๊ฐ€์— ์‚ฌ์šฉ 

 โ†ช๏ธŽ  112๊ฐœ์˜ ํ‰๊ฐ€๋ฐ์ดํ„ฐ 

 

 

 

2.3  SQL generation prompt

 

โ  ํ”„๋กฌํ”„ํŠธ ์ค€๋น„

 โ†ช๏ธŽ  LLM์ด SQL์„ ์ƒ์„ฑํ•˜๋„๋ก ํ•˜๊ธฐ ์œ„ํ•ด, ์ง€์‹œ์‚ฌํ•ญ๊ณผ ๋ฐ์ดํ„ฐ๋ฅผ ํฌํ•จํ•œ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ค€๋น„ํ•ด์•ผ ํ•œ๋‹ค. make_prompt ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ด ๋ช…๋ น prompt์™€ DDL (ํ…Œ์ด๋ธ” ์ƒ์„ฑ๋ฌธ), ์งˆ๋ฌธ, ์ •๋‹ตSQL๋ฌธ์„ ์–ป๋Š”๋‹ค. ํ•™์Šต์— ์‚ฌ์šฉ๋˜๋Š” ํ˜•ํƒœ๋Š” SQL ์ฟผ๋ฆฌ๊ฐ€ ์กด์žฌํ•˜๋Š” ํ˜•ํƒœ๊ณ , SQL์„ ์ƒ์„ฑํ•˜๋Š” ์ž‘์—…์„ ํ•  ๋•Œ์—๋Š” query๋ฅผ ์ž…๋ ฅํ•˜์ง€ ์•Š๋Š”๋‹ค (์ •๋‹ต SQL์„ ๋„ฃ์ง€ ์•Š์Œ). 

 

def make_prompt(ddl, question, query=''):
    # The Korean instruction below says: "You are an SQL bot that generates SQL.
    # Generate an SQL query that can resolve the Question using the tables in the DDL."
    prompt = f"""๋‹น์‹ ์€ SQL์„ ์ƒ์„ฑํ•˜๋Š” SQL ๋ด‡์ž…๋‹ˆ๋‹ค. DDL์˜ ํ…Œ์ด๋ธ”์„ ํ™œ์šฉํ•œ Question์„ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ๋Š” SQL ์ฟผ๋ฆฌ๋ฅผ ์ƒ์„ฑํ•˜์„ธ์š”.

### DDL:
{ddl}

### Question:
{question}

### SQL:
{query}"""
    
    return prompt

 

 

 

2.4  Preparing the GPT-4 evaluation prompt and code

 

โ  GPT-4๋ฅผ ํ†ตํ•œ ํ‰๊ฐ€ ์ˆ˜ํ–‰

 โ†ช๏ธŽ  GPT-4๋ฅผ ์‚ฌ์šฉํ•ด ํ‰๊ฐ€๋ฅผ ์ˆ˜ํ–‰ํ•˜๋ ค๊ณ  ํ•  ๋•Œ, ๋ฐ˜๋ณต์ ์œผ๋กœ API์š”์ฒญ์„ ๋ณด๋‚ด์•ผ ํ•œ๋‹ค. 

 โ†ช๏ธŽ  OpenAI๊ฐ€ ์ œ๊ณตํ•˜๋Š” ์ฝ”๋“œ๋ฅผ ํ™œ์šฉํ•˜๋ฉด ์š”์ฒญ ์ œํ•œ์„ ๊ด€๋ฆฌํ•˜๋ฉด์„œ ๋น„๋™๊ธฐ์ ์œผ๋กœ ์š”์ฒญ์„ ๋ณด๋‚ด ๋ฐ˜๋ณต์ ์ธ API ์š”์ฒญ์„ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ๋‹ค. 

 โ†ช๏ธŽ  (make_requests_for_gpt_evaluation) ํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ์…‹์„ ์ฝ์–ด GPT-4 API์š”์ฒญ์„ ๋ณด๋‚ผ jsonl ํŒŒ์ผ์„ ์ƒ์„ฑํ•œ๋‹ค. ์ƒ์„ฑ๋œ jsonํŒŒ์ผ์€ OpenAI๊ฐ€ ์ œ๊ณตํ•˜๋Š” (api_request_parallel_processor.py) ์ฝ”๋“œ๋ฅผ ํ†ตํ•ด ์ˆœ์ฐจ์ ์œผ๋กœ ์š”์ฒญ์„ ๋ณด๋‚ธ๋‹ค. 

 

*jsonl file : a file in JSON Lines format. Unlike .json, it has one JSON object per line (example below).
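
For example, a requests file with two entries is just two JSON objects, one per line (contents abridged and made up here):

{"model": "gpt-4-turbo-preview", "messages": [{"role": "system", "content": "evaluate example 1 ..."}]}
{"model": "gpt-4-turbo-preview", "messages": [{"role": "system", "content": "evaluate example 2 ..."}]}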

 

 

 

โ  ํ‰๊ฐ€๋ฅผ ์œ„ํ•œ ์š”์ฒญ jsonl ์ž‘์„ฑ ํ•จ์ˆ˜ 

 โ†ช๏ธŽ  LLM์— ์งˆ๋ฌธ์„ "๋ฐฐ์น˜(batch)"๋กœ ํ‰๊ฐ€ ์š”์ฒญํ•˜๋ ค๋ฉด, ์งˆ๋ฌธ๋“ค์„ ์ค„ ๋‹จ์œ„(JSONL ํ˜•์‹)๋กœ ์ •๋ฆฌํ•ด์•ผ ํ•˜๋Š”๋ฐ, jsonl ๊ฐ์ฒด์˜ ํ•œ ์ค„์ด ํ•œ ์งˆ๋ฌธ ํ‰๊ฐ€์š”์ฒญ์ด ๋˜๋„๋ก ๋ฐ์ดํ„ฐ๋ฅผ ๊ตฌ์„ฑํ•œ๋‹ค. 

 

import json
import pandas as pd
from pathlib import Path

def make_requests_for_gpt_evaluation(df, filename, dir='requests'):
  if not Path(dir).exists():
      Path(dir).mkdir(parents=True)
  prompts = []
  for idx, row in df.iterrows():
      prompts.append("""Based on below DDL and Question, evaluate gen_sql can resolve Question. If gen_sql and gt_sql do equal job, return "yes" else return "no". Output JSON Format: {"resolve_yn": ""}""" + f"""

DDL: {row['context']}
Question: {row['question']}
gt_sql: {row['answer']}
gen_sql: {row['gen_sql']}"""
)

  # Wrap each prompt in the line-level JSON structure expected by the GPT-4 chat API
  jobs = [{"model": "gpt-4-turbo-preview", "response_format" : { "type": "json_object" }, "messages": [{"role": "system", "content": prompt}]} for prompt in prompts]
  
  # Save in jsonl format
  with open(Path(dir, filename), "w") as f:
      for job in jobs:
          json_string = json.dumps(job)
          f.write(json_string + "\n")

# When this JSONL is sent to the OpenAI API, each query comes back with an evaluation response of "resolve_yn": "yes" or "no"

 

 

โ†ช๏ธŽ  Iterates over the input dataframe (df), builds the prompts used for evaluation, and writes the request contents to the jsonl file.

โ†ช๏ธŽ  Based on the DDL & Question, GPT-4 is asked to judge whether the SQL generated by the LLM (gen_sql) performs the same job as the ground-truth SQL.

 

 

 

โ  jsonl ํŒŒ์ผ์„ csv๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜  

 โ†ช๏ธŽ  ํ”„๋กฌํ”„ํŠธ์™€ ํŒ๋‹จ ๊ฒฐ๊ณผ๋ฅผ ๊ฐ๊ฐ (prompts, responses) ๋ณ€์ˆ˜์— ์ €์žฅํ•œ๋‹ค. 

 

def change_jsonl_to_csv(input_file, output_file, prompt_column="prompt", response_column="response"):
    prompts = []
    responses = []
    with open(input_file, 'r') as json_file:
        for data in json_file:
            # each line saved by api_request_parallel_processor.py is a [request, response] pair
            prompts.append(json.loads(data)[0]['messages'][0]['content'])
            responses.append(json.loads(data)[1]['choices'][0]['message']['content'])

    df = pd.DataFrame({prompt_column: prompts, response_column: responses})
    df.to_csv(output_file, index=False)
    return df

 

 

 

 

 


3. Hands-On: Fine-Tuning


 

3.1   Evaluating the base model

 

 

โ  ๊ธฐ์ดˆ๋ชจ๋ธ์„ ์„ ํƒ 

 โ†ช๏ธŽ  ๊ต์žฌ ์‹œ์ ์—์„œ 7B ์ดํ•˜ ํ•œ๊ตญ์–ด ์‚ฌ์ „ํ•™์Šต ๋ชจ๋ธ ์ค‘ ๊ฐ€์žฅ ๋†’์€ ์„ฑ๋Šฅ์„ ๋ณด์ด๋Š” Yi-Ko-6B ๋ชจ๋ธ์„ ๊ธฐ์ดˆ๋ชจ๋ธ๋กœ ์‚ฌ์šฉํ•œ๋‹ค. 

 

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

def make_inference_pipeline(model_id):
  tokenizer = AutoTokenizer.from_pretrained(model_id)
  # Load the model with 4-bit quantization (newer transformers versions expect these
  # options via a BitsAndBytesConfig passed as quantization_config instead)
  model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
  return pipe

 

 

 โ†ช๏ธŽ  The make_inference_pipeline function loads the pretrained tokenizer and model for the given model id and returns them wrapped in a single text-generation pipeline.

 

 

model_id = 'beomi/Yi-Ko-6B'
hf_pipe = make_inference_pipeline(model_id)

example = """๋‹น์‹ ์€ SQL์„ ์ƒ์„ฑํ•˜๋Š” SQL ๋ด‡์ž…๋‹ˆ๋‹ค. DDL์˜ ํ…Œ์ด๋ธ”์„ ํ™œ์šฉํ•œ Question์„ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ๋Š” SQL ์ฟผ๋ฆฌ๋ฅผ ์ƒ์„ฑํ•˜์„ธ์š”.

### DDL:
CREATE TABLE players (
  player_id INT PRIMARY KEY AUTO_INCREMENT,
  username VARCHAR(255) UNIQUE NOT NULL,
  email VARCHAR(255) UNIQUE NOT NULL,
  password_hash VARCHAR(255) NOT NULL,
  date_joined DATETIME NOT NULL,
  last_login DATETIME
);

### Question:
์‚ฌ์šฉ์ž ์ด๋ฆ„์— 'admin'์ด ํฌํ•จ๋˜์–ด ์žˆ๋Š” ๊ณ„์ •์˜ ์ˆ˜๋ฅผ ์•Œ๋ ค์ฃผ์„ธ์š”.

### SQL:
"""

# Feed the example data and check the result
hf_pipe(example, do_sample=False,
    return_full_text=False, max_length=512, truncation=True)

# ### Result
#  SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';

# ### SQL ๋ด‡:
# SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';

# ### SQL ๋ด‡์˜ ๊ฒฐ๊ณผ:
# SELECT COUNT(*) FROM players WHERE username LIKE '%admin%'; (rest of the output omitted)

 

 

 โ†ช๏ธŽ  The pipeline is stored in the hf_pipe variable and the example data is fed in to check the result. The model generates the requested SQL correctly, but it also produces extra output such as the 'SQL ๋ด‡:' and 'SQL ๋ด‡์˜ ๊ฒฐ๊ณผ:' sections. So the base model is already able to generate SQL on request, but additional training is needed to make it answer in the required format (a simple trimming workaround is sketched below).
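
As a simple stopgap (not from the book), the extra text can be trimmed by keeping only the output up to the first semicolon:

raw_output = hf_pipe(example, do_sample=False, return_full_text=False,
                     max_length=512, truncation=True)[0]['generated_text']
# Keep only the first SQL statement, up to and including the first semicolon
first_sql = raw_output.split(';')[0].strip() + ';'
print(first_sql)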

 

 

โ  ๊ธฐ์ดˆ๋ชจ๋ธ์„ ํ‰๊ฐ€ 

 โ†ช๏ธŽ  ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ์…‹์— ๋Œ€ํ•œ SQL์ƒ์„ฑ์„ ์ˆ˜ํ–‰ํ•˜๊ณ  GPT-4๋ฅผ ์‚ฌ์šฉํ•ด ํ‰๊ฐ€ํ•œ๋‹ค. 

 

from datasets import load_dataset

# 1. Load the dataset
df = load_dataset("shangrilar/ko_text2sql", "origin")['test']
df = df.to_pandas()

# 2. Build the prompts
for idx, row in df.iterrows():
  prompt = make_prompt(row['context'], row['question'])
  df.loc[idx, 'prompt'] = prompt
  
# 3. Generate SQL
gen_sqls = hf_pipe(df['prompt'].tolist(), do_sample=False,
                   return_full_text=False, max_length=512, truncation=True)
gen_sqls = [x[0]['generated_text'] for x in gen_sqls]
df['gen_sql'] = gen_sqls

# 4. Create the requests jsonl for evaluation
eval_filepath = "text2sql_evaluation.jsonl"
make_requests_for_gpt_evaluation(df, eval_filepath)

 

# 5. Run the GPT-4 evaluation
!python api_request_parallel_processor.py \
--requests_filepath requests/{eval_filepath}  \
--save_filepath results/{eval_filepath} \
--request_url https://api.openai.com/v1/chat/completions \
--max_requests_per_minute 2500 \
--max_tokens_per_minute 100000 \
--token_encoding_name cl100k_base \
--max_attempts 5 \
--logging_level 20

 

 

(1) Download the evaluation dataset with load_dataset,

(2) build the prompts used for LLM inference with the make_prompt function,

(3) feed the prompts into hf_pipe to generate SQL and store it in the gen_sqls variable,

(4) finally, create the jsonl file used for the GPT-4 evaluation, and

(5) send the evaluation requests to the GPT-4 API.

 

 โ†ช๏ธŽ  ๊ธฐ์ดˆ๋ชจ๋ธ์— ๋Œ€ํ•œ ํ‰๊ฐ€๋ฅผ ์ง„ํ–‰ํ–ˆ์„ ๋•Œ 112๊ฐœ ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ์…‹ ์ค‘ 21๊ฐœ๋ฅผ ์ •๋‹ต์œผ๋กœ ํŒ๋‹จํ–ˆ๋‹ค. 

 

 

 

 

3.2   Performing fine-tuning

 

โ  ๋ฏธ์„ธ์กฐ์ •

 โ†ช๏ธŽ  ํ•™์Šต ๋ฐ์ดํ„ฐ๋กœ fine tuning์„ ์ง„ํ–‰ํ•œ๋‹ค. autotrain-advanced ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. 

 

 

1) Load the training data

 

from datasets import load_dataset

df_sql = load_dataset("shangrilar/ko_text2sql", "origin")["train"]
df_sql = df_sql.to_pandas()
df_sql = df_sql.dropna().sample(frac=1, random_state=42)
df_sql = df_sql.query("db_id != 1") # exclude the domain held out for evaluation

for idx, row in df_sql.iterrows():
  df_sql.loc[idx, 'text'] = make_prompt(row['context'], row['question'], row['answer'])

!mkdir data
df_sql.to_csv('data/train.csv', index=False)

 

 

 

2) Fine-tune with autotrain-advanced

 

base_model = 'beomi/Yi-Ko-6B'
finetuned_model = 'yi-ko-6b-text2sql'

# autotrain
!autotrain llm \
--train \
--model {base_model} \
--project-name {finetuned_model} \
--data-path data/ \
--text-column text \
--lr 2e-4 \
--batch-size 8 \
--epochs 1 \
--block-size 1024 \
--warmup-ratio 0.1 \
--lora-r 16 \
--lora-alpha 32 \
--lora-dropout 0.05 \
--weight-decay 0.01 \
--gradient-accumulation 8 \
--mixed-precision fp16 \
--use-peft \
--quantization int4 \
--trainer sft

 

 

 

3) After training, merge the LoRA adapter with the base model

 

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftModel

model_name = base_model
device_map = {"": 0}

# Merge the LoRA adapter parameters into the base model
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)
model = PeftModel.from_pretrained(base_model, finetuned_model)
model = model.merge_and_unload()

# Tokenizer settings
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Push the model and tokenizer to the Hugging Face Hub
model.push_to_hub(finetuned_model, use_temp_dir=False)
tokenizer.push_to_hub(finetuned_model, use_temp_dir=False)

 

*LoRA : a method that trains only a small set of added parameters (makes fine-tuning fast and cheap); see the sketch below
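
Conceptually, LoRA freezes the original weight matrix W and learns a low-rank update, so the effective weight becomes W + (alpha / r) · B · A, with B and A of rank r. The --lora-r, --lora-alpha and --lora-dropout flags in the autotrain command above map to these hyperparameters; a peft LoraConfig with the same values would look roughly like this (a sketch only; autotrain builds its configuration internally):

from peft import LoraConfig

lora_config = LoraConfig(
    r=16,               # rank of the update matrices (--lora-r)
    lora_alpha=32,      # scaling factor alpha (--lora-alpha)
    lora_dropout=0.05,  # dropout applied to the LoRA layers (--lora-dropout)
    task_type="CAUSAL_LM",
)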

 

 

 

4) Generate SQL for the example data with the fine-tuned model

 

model_id = "shangrilar/yi-ko-6b-text2sql"
hf_pipe = make_inference_pipeline(model_id)

hf_pipe(example, do_sample=False,
       return_full_text=False, max_length=1024, truncation=True)
# SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';

 

 

 

5) Measure the fine-tuned model's performance

 

# Generate SQL
gen_sqls = hf_pipe(df['prompt'].tolist(), do_sample=False,
                   return_full_text=False, max_length=1024, truncation=True)
gen_sqls = [x[0]['generated_text'] for x in gen_sqls]
df['gen_sql'] = gen_sqls

# Create the requests jsonl for evaluation
ft_eval_filepath = "text2sql_evaluation_finetuned.jsonl"
make_requests_for_gpt_evaluation(df, ft_eval_filepath)

# Run the GPT-4 evaluation
!python api_request_parallel_processor.py \
  --requests_filepath requests/{ft_eval_filepath} \
  --save_filepath results/{ft_eval_filepath} \
  --request_url https://api.openai.com/v1/chat/completions \
  --max_requests_per_minute 2500 \
  --max_tokens_per_minute 100000 \
  --token_encoding_name cl100k_base \
  --max_attempts 5 \
  --logging_level 20
  
 
ft_eval = change_jsonl_to_csv(f"results/{ft_eval_filepath}", "results/yi_ko_6b_eval.csv", "prompt", "resolve_yn")
ft_eval['resolve_yn'] = ft_eval['resolve_yn'].apply(lambda x: json.loads(x)['resolve_yn'])
num_correct_answers = ft_eval.query("resolve_yn == 'yes'").shape[0]
num_correct_answers

*The accuracy comes out above 60% (compared with 21/112 for the base model)

 

 

 

 


3.3   Cleaning the training data and fine-tuning

 

โ  ํ•™์Šต ๋ฐ์ดํ„ฐ ํ’ˆ์งˆ ์ž์ฒด๋ฅผ ๊ฐœ์„ ํ–ˆ์„ ๋•Œ์˜ ํšจ๊ณผ 

 โ†ช๏ธŽ  GPT-4๋ฅผ ์‚ฌ์šฉํ•ด ํ•„ํ„ฐ๋ง์„ ์ˆ˜ํ–‰ : ์ •์ œํ•œ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ fine tuning์„ ํ–ˆ์„ ๋•Œ, ๋ฐ์ดํ„ฐ์…‹ ํฌ๊ธฐ๊ฐ€ 1/4์ •๋„ ์ค„์–ด๋“ค์—ˆ์œผ๋‚˜ ์ •์ œ ์ „์˜ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํ•™์Šตํ–ˆ์„ ๋•Œ์™€ ๋™์ผํ•œ ์„ฑ๋Šฅ์„ ๋ณด์˜€๋‹ค โ€ฃ ์ •์ œ์˜ ๊ธ์ •์  ํšจ๊ณผ!
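
The post does not include the filtering code. One plausible sketch, reusing the GPT-4 judging idea above on the training data instead of the evaluation data (the prompt wording and workflow here are assumptions, not the book's actual procedure):

# Hypothetical sketch: ask GPT-4 whether each training example's answer SQL really
# resolves its question given the DDL, and keep only the rows judged "yes".
def make_filter_prompt(row):
    return ("Based on the DDL and Question below, does the SQL correctly answer the Question? "
            'Output JSON Format: {"resolve_yn": ""}\n\n'
            f"DDL: {row['context']}\nQuestion: {row['question']}\nSQL: {row['answer']}")

filter_prompts = [make_filter_prompt(row) for _, row in df_sql.iterrows()]
# These prompts can be written to a jsonl and sent with api_request_parallel_processor.py
# exactly like the evaluation requests above; rows that come back "no" are dropped.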

 

 

 

 

 

 

 

