๐ Text2SQL
1. ๋ฐ์ดํฐ์ ๊ตฌ์ถ
2. ๋ชจ๋ธ ์ฑ๋ฅํ๊ฐ์ ์ฌ์ฉํ ํ์ดํ๋ผ์ธ ๊ตฌ์ถ
โบ sLLM : ํน์ ์์ ๋๋ ๋๋ฉ์ธ์ ํนํ๋ LLM


1. Text2SQL ๋ฐ์ดํฐ์
1.1 ๋ํ์ ์ธ Text2SQL๋ฐ์ดํฐ์
โ Text2SQL ๋ฐ์ดํฐ์
โช๏ธ ํ ์ด๋ธ ๋ฐ ์นผ๋ผ ์ ๋ณด (๋ฐ์ดํฐ๋ฒ ์ด์ค ์ ๋ณด)
โช๏ธ ์์ฒญ์ฌํญ (์ด๋ค ๋ฐ์ดํฐ๋ฅผ ์ถ์ถํ๊ณ ์ถ์์ง)
โ WikiSQL
โช๏ธ [ํ ์ด๋ธ๋ช , ์นผ๋ผ๋ช , ์นผ๋ผํ์, ์์ฒญ์ฌํญ(์์ฐ์ด), ์ ๋ตSQL์ฟผ๋ฆฌ๋ฌธ] ์ผ๋ก ๊ตฌ์ฑ๋ ๋ฐ์ดํฐ์ , ๋น๊ต์ ๊ฐ๋จํ ์ฟผ๋ฆฌ๋ฌธ ์์๋ก ๊ตฌ์ฑ๋์ด ์์
โ Spider
โช๏ธ ๋ณต์กํ SQL๋ฌธ๋ ํฌํจ๋ ๋ฐ์ดํฐ์
1.2 ํ๊ตญ์ด ๋ฐ์ดํฐ์
โ NL2SQL
โช๏ธ ๋น ์ฟผ๋ฆฌ์ ํ์ดํผํด๋ก๋ฐX๋ก ๊ตฌํํ๋ ์์
1.3 ํฉ์ฑ ๋ฐ์ดํฐ ํ์ฉ
โ ์ค์ต์ ์ํด ์์ฑํ ์์ ๋ฐ์ดํฐ
โช๏ธ db_id : ๋์ผํ id๋ฅผ ๊ฐ๋ ํ ์ด๋ธ์ ๊ฐ์ ๋๋ฉ์ธ (ex.๊ฒ์ - ๊ฒ์์ธ ์ํฉ์ ๊ฐ์ ํ๊ณ ์์ฑํ ํ ์ด๋ธ) ์ ๊ณต์ ํจ
โช๏ธ context : SQL์์ฑ์ ์ฌ์ฉํ ํ ์ด๋ธ ์ ๋ณด
โช๏ธ question : ๋ฐ์ดํฐ ์์ฒญ์ฌํญ
โช๏ธ answer : ์์ฒญ์ ๋ํ SQL ์ ๋ต

2. ์ฑ๋ฅ ํ๊ฐ ํ์ดํ๋ผ์ธ ์ค๋นํ๊ธฐ
๐ ์ค์ต์์๋ ์ผ๋ฐ์ ์ผ๋ก ์ฌ์ฉ๋๋ Text2SQL ํ๊ฐ๋ฐฉ์์ ์ฌ์ฉํ์ง ์๊ณ , GPT-4๋ฅผ ์ฌ์ฉํ์ฌ, ์์ฑ๋ SQL์ด ์ ๋ต์ธ์ง ํ๋จํ๋ ๋ฐฉ์์ ์ฌ์ฉํ๋ค.
2.1 Text2SQL ํ๊ฐ๋ฐฉ์
โ EM๋ฐฉ์
โช๏ธ ์์ฑํ SQL์ด ๋ฌธ์์ด ๊ทธ๋๋ก ๋์ผํ์ง ํ์ธํ๋ ๋ฐฉ์
โช๏ธ ์๋ฏธ์์ผ๋ก ๋์ผํ SQL์ฟผ๋ฆฌ๊ฐ ๋ค์ํ๊ฒ ๋์ฌ ์ ์๋๋ฐ ๋จ์ํ ๋ฌธ์์ด์ด ๋ค๋ฅด๋ฉด ๋ค๋ฅด๋ค๊ณ ํ๋จํ๋ ๋ฌธ์ ๋ฐ์
โ EX๋ฐฉ์
โช๏ธ ์คํ ์ ํ๋ ๋ฐฉ์์ผ๋ก, ์ฟผ๋ฆฌ๋ฅผ ์ํํ ์ ์๋ DB๋ฅผ ๋ง๋ค๊ณ , ํ๋ก๊ทธ๋๋ฐ ๋ฐฉ์์ผ๋ก SQL์ฟผ๋ฆฌ๋ฅผ ์ํํด ์ ๋ต๊ณผ ์ผ์นํ๋์ง ํ์ธํ๋ ๋ฐฉ์
โช๏ธ DB๋ฅผ ์ถ๊ฐ๋ก ์ค๋นํด์ผ ํ๊ธฐ ๋๋ฌธ์ ๋ฒ๊ฑฐ๋ก์
โ LLM์ LLM์ผ๋ก ํ๊ฐํ์
โช๏ธ ์ต๊ทผ์๋ LLM์ ํ์ฉํด LLM์ ์์ฑ ๊ฒฐ๊ณผ๋ฅผ ํ๊ฐํ๋ ๋ฐฉ์์ด ํ๋ฐํ ์ฐ๊ตฌ๋๊ณ ์๋ค. ์ด๋ฒ ์ค์ต์์๋ LLM์ด ์์ฑํ SQL์ด ๋ฐ์ดํฐ ์์ฒญ์ ์ ํด๊ฒฐํ๋์ง GPT-4๋ฅผ ์ฌ์ฉํด ํ์ธํ๋ค.
2.2 ํ๊ฐ ๋ฐ์ดํฐ์ ๊ตฌ์ถ
โ 8๊ฐ ๋ฐ์ดํฐ๋ฒ ์ด์ค(๋๋ฉ์ธ)์ ๋ํด ํฉ์ฑ ๋ฐ์ดํฐ์ ์ ์์ฑ
โช๏ธ 7๊ฐ๋ ํ์ต์ ์ฌ์ฉ, 1๊ฐ (๊ฒ์๋๋ฉ์ธ)์ ํ๊ฐ์ ์ฌ์ฉ
โช๏ธ 112๊ฐ์ ํ๊ฐ๋ฐ์ดํฐ
2.3 SQL์์ฑ ํ๋กฌํํธ
โ ํ๋กฌํํธ ์ค๋น
โช๏ธ LLM์ด SQL์ ์์ฑํ๋๋ก ํ๊ธฐ ์ํด, ์ง์์ฌํญ๊ณผ ๋ฐ์ดํฐ๋ฅผ ํฌํจํ ํ๋กฌํํธ๋ฅผ ์ค๋นํด์ผ ํ๋ค. make_prompt ํจ์๋ฅผ ์ฌ์ฉํด ๋ช ๋ น prompt์ DDL (ํ ์ด๋ธ ์์ฑ๋ฌธ), ์ง๋ฌธ, ์ ๋ตSQL๋ฌธ์ ์ป๋๋ค. ํ์ต์ ์ฌ์ฉ๋๋ ํํ๋ SQL ์ฟผ๋ฆฌ๊ฐ ์กด์ฌํ๋ ํํ๊ณ , SQL์ ์์ฑํ๋ ์์ ์ ํ ๋์๋ query๋ฅผ ์ ๋ ฅํ์ง ์๋๋ค (์ ๋ต SQL์ ๋ฃ์ง ์์).
def make_prompt(ddl, question, query=''):
prompt = f"""๋น์ ์ SQL์ ์์ฑํ๋ SQL ๋ด์
๋๋ค. DDL์ ํ
์ด๋ธ์ ํ์ฉํ Question์ ํด๊ฒฐํ ์ ์๋ SQL ์ฟผ๋ฆฌ๋ฅผ ์์ฑํ์ธ์.
### DDL:
{ddl}
### Question:
{question}
### SQL:
{query}"""
return prompt
2.4 GPT-4 ํ๊ฐ ํ๋กฌํํธ์ ์ฝ๋ ์ค๋น
โ GPT-4๋ฅผ ํตํ ํ๊ฐ ์ํ
โช๏ธ GPT-4๋ฅผ ์ฌ์ฉํด ํ๊ฐ๋ฅผ ์ํํ๋ ค๊ณ ํ ๋, ๋ฐ๋ณต์ ์ผ๋ก API์์ฒญ์ ๋ณด๋ด์ผ ํ๋ค.
โช๏ธ OpenAI๊ฐ ์ ๊ณตํ๋ ์ฝ๋๋ฅผ ํ์ฉํ๋ฉด ์์ฒญ ์ ํ์ ๊ด๋ฆฌํ๋ฉด์ ๋น๋๊ธฐ์ ์ผ๋ก ์์ฒญ์ ๋ณด๋ด ๋ฐ๋ณต์ ์ธ API ์์ฒญ์ ์ฒ๋ฆฌํ ์ ์๋ค.
โช๏ธ (make_requests_for_gpt_evaluation) ํจ์๋ฅผ ํตํด ํ๊ฐ ๋ฐ์ดํฐ์ ์ ์ฝ์ด GPT-4 API์์ฒญ์ ๋ณด๋ผ jsonl ํ์ผ์ ์์ฑํ๋ค. ์์ฑ๋ jsonํ์ผ์ OpenAI๊ฐ ์ ๊ณตํ๋ (api_request_parallel_processor.py) ์ฝ๋๋ฅผ ํตํด ์์ฐจ์ ์ผ๋ก ์์ฒญ์ ๋ณด๋ธ๋ค.
*jsonl ํ์ผ : JSON Lines ํฌ๋งท์ ํ์ผ์ ์๋ฏธํ๋ค. .json๊ณผ๋ ๋ค๋ฅด๊ฒ, ํ์ค์ ํ๋์ JSON ๊ฐ์ฒด๊ฐ ์๋ ๊ตฌ์กฐ์ด๋ค.

โ ํ๊ฐ๋ฅผ ์ํ ์์ฒญ jsonl ์์ฑ ํจ์
โช๏ธ LLM์ ์ง๋ฌธ์ "๋ฐฐ์น(batch)"๋ก ํ๊ฐ ์์ฒญํ๋ ค๋ฉด, ์ง๋ฌธ๋ค์ ์ค ๋จ์(JSONL ํ์)๋ก ์ ๋ฆฌํด์ผ ํ๋๋ฐ, jsonl ๊ฐ์ฒด์ ํ ์ค์ด ํ ์ง๋ฌธ ํ๊ฐ์์ฒญ์ด ๋๋๋ก ๋ฐ์ดํฐ๋ฅผ ๊ตฌ์ฑํ๋ค.
import json
import pandas as pd
from pathlib import Path
def make_requests_for_gpt_evaluation(df, filename, dir='requests'):
if not Path(dir).exists():
Path(dir).mkdir(parents=True)
prompts = []
for idx, row in df.iterrows():
prompts.append("""Based on below DDL and Question, evaluate gen_sql can resolve Question. If gen_sql and gt_sql do equal job, return "yes" else return "no". Output JSON Format: {"resolve_yn": ""}""" + f"""
DDL: {row['context']}
Question: {row['question']}
gt_sql: {row['answer']}
gen_sql: {row['gen_sql']}"""
)
# ๊ฐ ํ๋กฌํํธ๋ฅผ GPT-4 ์์ฒญ์ ์ ํฉํ ์ค ๋จ์ JSON ํ์์ผ๋ก ์ ๋ฆฌ
jobs = [{"model": "gpt-4-turbo-preview", "response_format" : { "type": "json_object" }, "messages": [{"role": "system", "content": prompt}]} for prompt in prompts]
# jsonl ํ์์ผ๋ก ์ ์ฅ
with open(Path(dir, filename), "w") as f:
for job in jobs:
json_string = json.dumps(job)
f.write(json_string + "\n")
# ์ด JSONL์ OpenAI API์ ์ ๋ฌํ๋ฉด, ๊ฐ ์ฟผ๋ฆฌ์ ๋ํด "resolve_yn": "yes" ๋๋ "no"๋ก ํ๊ฐ ์๋ต์ ๋ฐ๊ฒ ๋จ
โช๏ธ ์ ๋ ฅํ ๋ฐ์ดํฐํ๋ ์(df)๋ฅผ ์ํํ๋ฉฐ, ํ๊ฐ์ ์ฌ์ฉํ propmt๋ฅผ ์์ฑํ๊ณ , jsonlํ์ผ์ ์์ฒญํ ๋ด์ฉ์ ๊ธฐ๋กํ๋ค.
โช๏ธ DDL & Question์ ๋ฐํ์ผ๋ก LLM์ด ์์ฑํ SQL (gen_sql)์ด ์ ๋ต SQL๊ณผ ๋์ผํ ๊ธฐ๋ฅ์ ํ๋์ง ํ๊ฐํ๋๋ก ํ๋ค.
โ jsonl ํ์ผ์ csv๋ก ๋ณํํ๋ ํจ์
โช๏ธ ํ๋กฌํํธ์ ํ๋จ ๊ฒฐ๊ณผ๋ฅผ ๊ฐ๊ฐ (prompts, responses) ๋ณ์์ ์ ์ฅํ๋ค.
def change_jsonl_to_csv(input_file, output_file, prompt_column="prompt", response_column="response"):
prompts = []
responses = []
with open(input_file, 'r') as json_file:
for data in json_file:
prompts.append(json.loads(data)[0]['messages'][0]['content'])
responses.append(json.loads(data)[1]['choices'][0]['message']['content'])
df = pd.DataFrame({prompt_column: prompts, response_column: responses})
df.to_csv(output_file, index=False)
return df
3. ์ค์ต : ๋ฏธ์ธ์กฐ์ ์ํํ๊ธฐ
3.1 ๊ธฐ์ด๋ชจ๋ธํ๊ฐํ๊ธฐ
โ ๊ธฐ์ด๋ชจ๋ธ์ ์ ํ
โช๏ธ ๊ต์ฌ ์์ ์์ 7B ์ดํ ํ๊ตญ์ด ์ฌ์ ํ์ต ๋ชจ๋ธ ์ค ๊ฐ์ฅ ๋์ ์ฑ๋ฅ์ ๋ณด์ด๋ Yi-Ko-6B ๋ชจ๋ธ์ ๊ธฐ์ด๋ชจ๋ธ๋ก ์ฌ์ฉํ๋ค.
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
def make_inference_pipeline(model_id):
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
return pipe
โช๏ธ make_inference_pipeline ํจ์๋ ์ ๋ ฅํ ๋ชจ๋ธ ์์ด๋์ ๋ง์ถ์ด ์ฌ์ ํ์ต๋ ํ ํฌ๋์ด์ ์ ๋ชจ๋ธ์ ๋ถ๋ฌ์ค๊ณ ํ๋์ ํ์ดํ๋ผ์ธ์ผ๋ก ๋ง๋ค์ด ๋ฐํํ๋ค.
model_id = 'beomi/Yi-Ko-6B'
hf_pipe = make_inference_pipeline(model_id)
example = """๋น์ ์ SQL์ ์์ฑํ๋ SQL ๋ด์
๋๋ค. DDL์ ํ
์ด๋ธ์ ํ์ฉํ Question์ ํด๊ฒฐํ ์ ์๋ SQL ์ฟผ๋ฆฌ๋ฅผ ์์ฑํ์ธ์.
### DDL:
CREATE TABLE players (
player_id INT PRIMARY KEY AUTO_INCREMENT,
username VARCHAR(255) UNIQUE NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
password_hash VARCHAR(255) NOT NULL,
date_joined DATETIME NOT NULL,
last_login DATETIME
);
### Question:
์ฌ์ฉ์ ์ด๋ฆ์ 'admin'์ด ํฌํจ๋์ด ์๋ ๊ณ์ ์ ์๋ฅผ ์๋ ค์ฃผ์ธ์.
### SQL:
"""
# ์์ ๋ฐ์ดํฐ๋ฅผ ์
๋ ฅํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ํ์ธ
hf_pipe(example, do_sample=False,
return_full_text=False, max_length=512, truncation=True)
# ### ๊ฒฐ๊ณผ
# SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';
# ### SQL ๋ด:
# SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';
# ### SQL ๋ด์ ๊ฒฐ๊ณผ:
# SELECT COUNT(*) FROM players WHERE username LIKE '%admin%'; (์๋ต)
โช๏ธ hf_pipe ๋ณ์์ ํ์ดํ๋ผ์ธ์ ์ ์ฅํ๊ณ , example ๋ฐ์ดํฐ๋ฅผ ์ ๋ ฅํด ๊ฒฐ๊ณผ๋ฅผ ํ์ธํ๋ค. ์์ฒญํ๋๋ก SQL์ ์ ์์ฑํ์ง๋ง, 'SQL ๋ด, SQL ๋ด์ ๊ฒฐ๊ณผ' ๋ผ๋ ๋ถ๋ถ๊ณผ ๊ฐ์ด ์ถ๊ฐ์ ์ธ ๊ฒฐ๊ณผ๋ฅผ ์์ฑํ ๊ฒ์ ๋ณผ ์ ์๋๋ฐ, ๊ธฐ์ด๋ชจ๋ธ๋ ์์ฒญ์ ๋ฐ๋ผ SQL์ ์์ฑํ ์ ์๋ ๋ฅ๋ ฅ์ ์์ผ๋, ํ์์ ๋ง์ถฐ ๋ต๋ณํ๋๋ก ํ๊ธฐ ์ํด์๋ ์ถ๊ฐํ์ต์ด ํ์ํ๋ค๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
โ ๊ธฐ์ด๋ชจ๋ธ์ ํ๊ฐ
โช๏ธ ํ๊ฐ ๋ฐ์ดํฐ์ ์ ๋ํ SQL์์ฑ์ ์ํํ๊ณ GPT-4๋ฅผ ์ฌ์ฉํด ํ๊ฐํ๋ค.
from datasets import load_dataset
# 1. ๋ฐ์ดํฐ์
๋ถ๋ฌ์ค๊ธฐ
df = load_dataset("shangrilar/ko_text2sql", "origin")['test']
df = df.to_pandas()
# 2. ํ๋กฌํํธ ์์ฑ
for idx, row in df.iterrows():
prompt = make_prompt(row['context'], row['question'])
df.loc[idx, 'prompt'] = prompt
# 3. sql ์์ฑ
gen_sqls = hf_pipe(df['prompt'].tolist(), do_sample=False,
return_full_text=False, max_length=512, truncation=True)
gen_sqls = [x[0]['generated_text'] for x in gen_sqls]
df['gen_sql'] = gen_sqls
# 4. ํ๊ฐ๋ฅผ ์ํ requests.jsonl ์์ฑ
eval_filepath = "text2sql_evaluation.jsonl"
make_requests_for_gpt_evaluation(df, eval_filepath)
# 5. GPT-4 ํ๊ฐ ์ํ
!python api_request_parallel_processor.py \
--requests_filepath requests/{eval_filepath} \
--save_filepath results/{eval_filepath} \
--request_url https://api.openai.com/v1/chat/completions \
--max_requests_per_minute 2500 \
--max_tokens_per_minute 100000 \
--token_encoding_name cl100k_base \
--max_attempts 5 \
--logging_level 20
(1) ํ๊ฐ๋ฐ์ดํฐ์ ์ load_dataset์ผ๋ก ๋ด๋ ค๋ฐ๊ณ
(2) make_prompt ํจ์๋ก LLM ์ถ๋ก ์ ์ฌ์ฉํ ํ๋กฌํํธ๋ฅผ ์์ฑํ ํ
(3) hf_pipe์ ์์ฑํ ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํด SQL์ ์์ฑํ๊ณ gen_sqls ๋ณ์์ ์ ์ฅํ ํ
(4) ๋ง์ง๋ง์ผ๋ก GPT-4 ํ๊ฐ์ ์ฌ์ฉํ jsonl ํ์ผ์ ๋ง๋ค๊ณ
(5) GPT-4 API์ ํ๊ฐ๋ฅผ ์์ฒญํ๋ค.
โช๏ธ ๊ธฐ์ด๋ชจ๋ธ์ ๋ํ ํ๊ฐ๋ฅผ ์งํํ์ ๋ 112๊ฐ ํ๊ฐ ๋ฐ์ดํฐ์ ์ค 21๊ฐ๋ฅผ ์ ๋ต์ผ๋ก ํ๋จํ๋ค.
3.2 ๋ฏธ์ธ์กฐ์ ์ํ
โ ๋ฏธ์ธ์กฐ์
โช๏ธ ํ์ต ๋ฐ์ดํฐ๋ก fine tuning์ ์งํํ๋ค. autotrain-advanced ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ค.
1) ํ์ต ๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ
from datasets import load_dataset
df_sql = load_dataset("shangrilar/ko_text2sql", "origin")["train"]
df_sql = df_sql.to_pandas()
df_sql = df_sql.dropna().sample(frac=1, random_state=42)
df_sql = df_sql.query("db_id != 1") # ํ๊ฐ์ ์ฌ์ฉํ ๋ฐ์ดํฐ๋ ์ ์ธ
for idx, row in df_sql.iterrows():
df_sql.loc[idx, 'text'] = make_prompt(row['context'], row['question'], row['answer'])
!mkdir data
df_sql.to_csv('data/train.csv', index=False)
2) autotrain-advanced ๋ฅผ ์ด์ฉํด ๋ฏธ์ธ์กฐ์
base_model = 'beomi/Yi-Ko-6B'
finetuned_model = 'yi-ko-6b-text2sql'
# autotrain
!autotrain llm \
--train \
--model {base_model} \
--project-name {finetuned_model} \
--data-path data/ \
--text-column text \
--lr 2e-4 \
--batch-size 8 \
--epochs 1 \
--block-size 1024 \
--warmup-ratio 0.1 \
--lora-r 16 \
--lora-alpha 32 \
--lora-dropout 0.05 \
--weight-decay 0.01 \
--gradient-accumulation 8 \
--mixed-precision fp16 \
--use-peft \
--quantization int4 \
--trainer sft
3) ํ์ต ํ, LoRA ์ด๋ํฐ์ ๊ธฐ์ด๋ชจ๋ธ์ ํฉ์น๊ธฐ
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftModel
model_name = base_model
device_map = {"": 0}
# LoRA์ ๊ธฐ์ด ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ ํฉ์น๊ธฐ
base_model = AutoModelForCausalLM.from_pretrained(
model_name,
low_cpu_mem_usage=True,
return_dict=True,
torch_dtype=torch.float16,
device_map=device_map,
)
model = PeftModel.from_pretrained(base_model, finetuned_model)
model = model.merge_and_unload()
# ํ ํฌ๋์ด์ ์ค์
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# ํ๊น
ํ์ด์ค ํ๋ธ์ ๋ชจ๋ธ ๋ฐ ํ ํฌ๋์ด์ ์ ์ฅ
model.push_to_hub(finetuned_model, use_temp_dir=False)
tokenizer.push_to_hub(finetuned_model, use_temp_dir=False)
*LoRA : ์ผ๋ถ ํ๋ผ๋ฏธํฐ๋ง ์ถ๊ฐํ์ต ํ๋ ๋ฐฉ๋ฒ (๋น ๋ฅด๊ณ ์ ๋ ดํ๊ฒ fine tuning์ ํ ์ ์๊ฒ ํจ)
4) ๋ฏธ์ธ์กฐ์ ํ ๋ชจ๋ธ๋ก ์์ ๋ฐ์ดํฐ์ ๋ํ SQL ์์ฑ
model_id = "shangrilar/yi-ko-6b-text2sql"
hf_pipe = make_inference_pipeline(model_id)
hf_pipe(example, do_sample=False,
return_full_text=False, max_length=1024, truncation=True)
# SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';
5) ๋ฏธ์ธ์กฐ์ ํ ๋ชจ๋ธ ์ฑ๋ฅ ์ธก์
# sql ์์ฑ ์ํ
gen_sqls = hf_pipe(df['prompt'].tolist(), do_sample=False,
return_full_text=False, max_length=1024, truncation=True)
gen_sqls = [x[0]['generated_text'] for x in gen_sqls]
df['gen_sql'] = gen_sqls
# ํ๊ฐ๋ฅผ ์ํ requests.jsonl ์์ฑ
ft_eval_filepath = "text2sql_evaluation_finetuned.jsonl"
make_requests_for_gpt_evaluation(df, ft_eval_filepath)
# GPT-4 ํ๊ฐ ์ํ
!python api_request_parallel_processor.py \
--requests_filepath requests/{ft_eval_filepath} \
--save_filepath results/{ft_eval_filepath} \
--request_url https://api.openai.com/v1/chat/completions \
--max_requests_per_minute 2500 \
--max_tokens_per_minute 100000 \
--token_encoding_name cl100k_base \
--max_attempts 5 \
--logging_level 20
ft_eval = change_jsonl_to_csv(f"results/{ft_eval_filepath}", "results/yi_ko_6b_eval.csv", "prompt", "resolve_yn")
ft_eval['resolve_yn'] = ft_eval['resolve_yn'].apply(lambda x: json.loads(x)['resolve_yn'])
num_correct_answers = ft_eval.query("resolve_yn == 'yes'").shape[0]
num_correct_answers
*์ ๋ต๋ฅ ์ด 60%์ด์๊น์ง ๋์ด
3.3 ํ์ต ๋ฐ์ดํฐ ์ ์ ์ ๋ฏธ์ธ์กฐ์
โ ํ์ต ๋ฐ์ดํฐ ํ์ง ์์ฒด๋ฅผ ๊ฐ์ ํ์ ๋์ ํจ๊ณผ
โช๏ธ GPT-4๋ฅผ ์ฌ์ฉํด ํํฐ๋ง์ ์ํ : ์ ์ ํ ๋ฐ์ดํฐ์ ์ผ๋ก fine tuning์ ํ์ ๋, ๋ฐ์ดํฐ์ ํฌ๊ธฐ๊ฐ 1/4์ ๋ ์ค์ด๋ค์์ผ๋ ์ ์ ์ ์ ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ตํ์ ๋์ ๋์ผํ ์ฑ๋ฅ์ ๋ณด์๋ค โฃ ์ ์ ์ ๊ธ์ ์ ํจ๊ณผ!
'1๏ธโฃ AIโขDS > ๐ LLM' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
| [์ฑ ์คํฐ๋] 8. sLLM ์๋นํ๊ธฐ (0) | 2025.09.06 |
|---|---|
| [์ฑ ์คํฐ๋] 7. ๋ชจ๋ธ ๊ฐ๋ณ๊ฒ ๋ง๋ค๊ธฐ (3) | 2025.08.23 |
| [์ฑ ์คํฐ๋] 5-2. GPU ํจ์จ์ ์ธ ํ์ต (4) | 2025.08.02 |
| [์ฑ ์คํฐ๋] 5-1. GPU ํจ์จ์ ์ธ ํ์ต (3) | 2025.07.24 |
| [์ฑ ์คํฐ๋] 4. GPT-3๊ฐ ์ฑGPT๋ก ๋ฐ์ ํ ์ ์์๋ ๋ฐฉ๋ฒ (0) | 2025.07.06 |
๋๊ธ