반응형
[AIF] 개인 데이터셋을 통한 llama2 fine-tune
- AIF에서 발표한 내용을 다시 작성한 Post입니다, 해당 강의는 llama2를 사용한 fine-tune의 간편화에 목적을 두었습니다.
- 강연에서는 Gradientai, SFTTrainer 을 예시로 사용했습니다.
- https://github.com/lxe/simple-llm-finetuner
- 해당 github의 simple-llm-finetuner가 llama2 모델을 지원하지 않고, 추가 업데이트 없이 종료되어 해당 github의 코드를 사용하여 easy_finetuner를 만들었습니다. (https://github.com/choijhyeok/easy_finetuner?tab=readme-ov-file)
- 해당 github의 Trainer를 SFTTrainer로 변경
- parameter를 Colab의 T4상에서 훈련할수 있도록 변경
- SFTTraienr parameter, T4 GPU 최적화 setting등 내부 config 정리
- 필요없는 기능 제거, Gradio에서 데이터셋 미리보기 기능 추가
- Huggingface의 SSH KEY 입력 추가
- 해당 github의 simple-llm-finetuner가 llama2 모델을 지원하지 않고, 추가 업데이트 없이 종료되어 해당 github의 코드를 사용하여 easy_finetuner를 만들었습니다. (https://github.com/choijhyeok/easy_finetuner?tab=readme-ov-file)
현재 Gradientai는 정상 작동하지 않아서 easy_finetuner를 설명하겠습니다.
Easy Finetuner 실행
- 사용한 패키지의 버전은 제 github의 https://github.com/choijhyeok/easy_finetuner 에서 requirements.txt를 참고해 주세요.
- 해당 github를 git clone하고, 양식에 맞게 데이터 저장을 우선적으로 수행해야 합니다.
DATASET 생성
- 아래 코드와 같이 text 혹은 jsonl등등 text가 담겨있는 데이터를 huggingface의 Dataset형식으로 생성해 주세요.
- 이때 중요한건 열 이름은 반드시 'text' 하나로 된 단일 열에 "instruction
~ response ~"과 같이 text 형태로 넣어주시면 됩니다.- 단일 열로 만든 이유는 dataset마다 열이름이 달라서 하나로 통일하고자 수행했습니다. 기본적으로 SFTTrainer를 사용하게 만들었기 때문에 Instruction Context Question Answer 혹은 Instruction Question Answer 형태로 작업해주시면 됩니다.
bergerking_dataset = []
with jsonlines.open("/content/qa_버거킹_train.jsonl") as f:
for line in f.iter():
# bergerking_dataset.append(f'<s>[INST] {line["inputs"]} [/INST] {line["response"]} </s>')
bergerking_dataset.append(f'<s>### Instruction: \n{line["inputs"]} \n\n### Response: \n{line["response"]}</s>')
# 데이터셋 확인
print('데이터셋 확인')
print(bergerking_dataset[:5])
# 데이터셋 생성 및 저장
burgerking_dataset = Dataset.from_dict({"text": bergerking_dataset})
burgerking_dataset.save_to_disk('/content/easy_finetuner/example-datasets/burgerking_dataset')
# 데이터셋 info 확인
print('데이터셋 info 확인')
print(burgerking_dataset)
실행화면 설명
SFTTraienr의 parameter를 설정하여 Lora Fine-Tune을 수행
Easy Finetuner github의 config에 아래와 같이 설명이 추가되어 있습니다.
parser.add_argument('--num_train_epochs', type=int, default=1, help='epochs 수') parser.add_argument('--learning_rate', type=float, default=2e-4, help='Learning rate') parser.add_argument('--output_dir', type=str, default='./result', help='저장위치') parser.add_argument('--per_device_train_batch_size', type=int, default=4, help='훈련당 train gpu batch size') ...
%cd /content/easy_finetuner
!python app.py --share
fine-tune llama2 RAG
- 학습한 adapter를 base_model에 연결해서 RAG를 수행한 과정
from transformers import AutoTokenizer, LlamaForCausalLM
import torch
import argparse
import os
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
HfArgumentParser,
TrainingArguments,
pipeline,
logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import warnings
warnings.filterwarnings('ignore')
def llama2_prompt(input_text):
return f'### Instruction:\n{input_text}\n\n### Response:'
def llama2_output(ouput_text):
sep = ouput_text[0]['generated_text'].split('### Response:')[1].split('### Instruction')[0].split('## Instruction')[0].split('# Instruction')[0].split('Instruction')[0]
sep = sep[1:] if sep[0] == '.' else sep
sep = sep[:sep.find('.')+1] if '.' in sep else sep
return sep
adapter_name = 'beomi_llama-2-ko-7b_burgerking-ko'
compute_dtype = getattr(torch, 'float16')
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type='nf4',
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=False
)
model = AutoModelForCausalLM.from_pretrained('beomi/llama-2-ko-7b', quantization_config=bnb_config, device_map={'': 0}, use_auth_token=os.environ["huggingface_token"])
model.config.use_cache = False
model.config.pretraining_tp = 1
tokenizer = AutoTokenizer.from_pretrained('beomi/llama-2-ko-7b', trust_remote_code=True, use_auth_token=os.environ["huggingface_token"])
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
model = PeftModel.from_pretrained(model, f'/content/easy_finetuner/lora/{adapter_name}')
Huggingface Pipeline
- huggingface의 pipeline을 통해서 추론
pipe = pipeline(task="text-generation",
model=model,
tokenizer=tokenizer,
max_length=150,
do_sample=True,
temperature=0.1,
num_return_sequences=1,
eos_token_id=tokenizer.eos_token_id,
top_k=3,
# top_p=0.3,
repetition_penalty = 1.3,
framework='pt'
# early_stopping=True
)
prompt = "버거킹에서 판매하는 스프라이트 제로의 특징은 무엇인가요?"
result = pipe(llama2_prompt(prompt))
print(llama2_output(result))
>> 스프라이트 제로는 0kcal이며, 칼로리 걱정 없이 즐길 수 있는 음료입니다.
Langchain과 연결
- 이전에 강연했던 "[AIF] 챗GPT 점메추 메뉴판, 예산입력하고 점심 메뉴 추천받자"의 코드에 적용한 내용입니다.
def bkchain_output(text):
text = text.split('Machine:')[1] if 'Machine:' in text else text
text = text.split('Human:')[0] if 'Human:' in text else text
return text.strip()
system_template="""To answer the question at the end, use the following context. If you don't know the answer, just say you don't know and don't try to make up an answer.
I want you to act as my Burger King menu recommender. It tells you your budget and suggests what to buy.
please answer in korean.
{summaries}
"""
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}")
]
prompt = ChatPromptTemplate.from_messages(messages)
chain_type_kwargs = {"prompt": prompt}
bk_chain = RetrievalQAWithSourcesChain.from_chain_type(
llm,
chain_type="stuff",
retriever=retriever,
chain_type_kwargs=chain_type_kwargs,
return_source_documents=True,
reduce_k_below_max_tokens=True
)
result = bk_chain({"question": "스모키 바비큐 X의 독특한 특징은 무엇인가요?"})
print(f"질문 : {result['question']}")
print()
print(f"답변 : {bkchain_output(result['answer'])}")
>> 질문 : 스모키 바비큐 X의 독특한 특징은 무엇인가요?
답변 : 스모키 바비큐 X는 강력한 불향을 입혀 더욱 깊은 풍미를 느낄 수 있습니다. 또한 육즙 가득한 쇠고 기 패티 위에 아낌없이 뿌려진 모짜렐라 치즈의 환상의 조합을 자랑합니다.
Gradio chat UI
- fine-tune을 수행한 llm을 Gradio를 통한 chat 예제
import os
import logging
import sys
import gradio as gr
import torch
import gc
def reset_state():
return [], [], "Reset Done"
def reset_textbox():
return gr.update(value=""),""
def transfer_input(inputs):
textbox = reset_textbox()
return (
inputs,
gr.update(value=""),
gr.Button.update(visible=True),
)
title = """<h1 align="left" style="min-width:350px; margin-top:0;"> <img src="https://lh3.google.com/u/0/d/1txdmhh6pWjdJBpqGBRMdC0qQX2f7pzxI=w2020-h952-iv1" width="32px" style="display: inline"> AIF 버거킹 chat </h1>"""
description_top = """\
<div align="left">
<p></p>
<p>
</p >
</div>
"""
CONCURRENT_COUNT = 100
ALREADY_CONVERTED_MARK = "<!-- ALREADY CONVERTED BY PARSER. -->"
small_and_beautiful_theme = gr.themes.Soft(
primary_hue=gr.themes.Color(
c50="#02C160",
c100="rgba(2, 193, 96, 0.2)",
c200="#02C160",
c300="rgba(2, 193, 96, 0.32)",
c400="rgba(2, 193, 96, 0.32)",
c500="rgba(2, 193, 96, 1.0)",
c600="rgba(2, 193, 96, 1.0)",
c700="rgba(2, 193, 96, 0.32)",
c800="rgba(2, 193, 96, 0.32)",
c900="#02C160",
c950="#02C160",
),
secondary_hue=gr.themes.Color(
c50="#576b95",
c100="#576b95",
c200="#576b95",
c300="#576b95",
c400="#576b95",
c500="#576b95",
c600="#576b95",
c700="#576b95",
c800="#576b95",
c900="#576b95",
c950="#576b95",
),
neutral_hue=gr.themes.Color(
name="gray",
c50="#f9fafb",
c100="#f3f4f6",
c200="#e5e7eb",
c300="#d1d5db",
c400="#B2B2B2",
c500="#808080",
c600="#636363",
c700="#515151",
c800="#393939",
c900="#272727",
c950="#171717",
),
radius_size=gr.themes.sizes.radius_sm,
).set(
button_primary_background_fill="#06AE56",
button_primary_background_fill_dark="#06AE56",
button_primary_background_fill_hover="#07C863",
button_primary_border_color="#06AE56",
button_primary_border_color_dark="#06AE56",
button_primary_text_color="#FFFFFF",
button_primary_text_color_dark="#FFFFFF",
button_secondary_background_fill="#F2F2F2",
button_secondary_background_fill_dark="#2B2B2B",
button_secondary_text_color="#393939",
button_secondary_text_color_dark="#FFFFFF",
# background_fill_primary="#F7F7F7",
# background_fill_primary_dark="#1F1F1F",
block_title_text_color="*primary_500",
block_title_background_fill="*primary_100",
input_background_fill="#F6F6F6",
)
with open("/content/easy_finetuner/custom.css", "r", encoding="utf-8") as f:
customCSS = f.read()
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)
total_count = 0
def predict(input_text,
history):
global bk_chain
result = bk_chain({"question": input_text})
history = history + [((input_text, None))]
history = history + [((None, bkchain_output(result['answer'])))]
return history, history, "Generate: Success"
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
history = gr.State([])
user_question = gr.State("")
with gr.Row():
gr.HTML(title)
status_display = gr.Markdown("Success", elem_id="status_display")
gr.Markdown(description_top)
with gr.Row(scale=1).style(equal_height=True):
with gr.Column(scale=5):
with gr.Row(scale=1):
chatbot = gr.Chatbot(avatar_images=('https://yt3.googleusercontent.com/_JbQDtNPfI8h6RPW_9Og5qlGhSBhpMp5qX3JR7iNeSC9XZL4btbNE3dFB4ec77tauPA-nLGQTQ=s900-c-k-c0x00ffffff-no-rj', 'https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7'),elem_id="chuanhu_chatbot").style(height="100%")
with gr.Row(scale=1):
with gr.Column(scale=12):
user_input = gr.Textbox(
show_label=False, placeholder="Enter text"
).style(container=False)
with gr.Column(min_width=70, scale=1):
submitBtn = gr.Button("Send")
with gr.Column(min_width=70, scale=1):
cancelBtn = gr.Button("Stop")
with gr.Row(scale=1):
emptyBtn = gr.Button(
"🧹 New Conversation",
)
predict_args = dict(
fn=predict,
inputs=[
user_question,
history
],
outputs=[chatbot, history, status_display],
show_progress=True,
)
reset_args = dict(
fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
)
# Chatbot
transfer_input_args = dict(
fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
)
predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)
gr.Markdown("<h2>버거킹 chat 시연 리스트</h2>")
gr.Examples(
examples=[
"스모키 바비큐 X의 독특한 특징은 무엇인가요?",
"버거킹에서 판매하는 스프라이트 제로의 특징은 무엇인가요?",
"롱치킨버거에는 어떤 소스와 야채가 사용되었나요?"
],
inputs=user_input
)
emptyBtn.click(
reset_state,
outputs=[chatbot, history, status_display],
show_progress=True,
)
emptyBtn.click(**reset_args)
demo.queue(concurrency_count=1).launch()
전체 과정 한번에 보기
이번 강연에서는 llama2를 활용한 fine-tune, langchain 연결을 설명했습니다.
해당 강연에서는 기본적인 내용만 빠르게 다루기 위해서 SFTTrainer에 대해서는 깊게 적용하지 않았습니다.
자세한 내용은 Youtube 강연을 확인해 주세요.
감사합니다.
반응형
'Study > Seminar' 카테고리의 다른 글
[AIF] 챗GPT 점메추 메뉴판, 예산입력하고 점심 메뉴 추천받자 (17) | 2024.07.22 |
---|---|
[AIFLD2023] 어쩌다 키포인트 검출 제대로 입문하기 (2) | 2024.07.22 |
[Space-S x DLD 2022] 케라스 실용 예제 및 개발 가이드 컨퍼런스 - Advanced Augmentation Strategy in Keras 리뷰 (1) | 2024.07.22 |
INNOPOLIS AI SPACE-S 인공지능 세미나 - 이미지 분류를 위한 딥러닝 문제해결 패턴 리뷰 (1) | 2024.07.22 |
INNOPOLIS AI SPACE-S 인공지능 세미나 - 정형 데이터를 다루는 머신러닝 문제해결 패턴 리뷰 (2) | 2024.07.22 |