반응형

 

[AIF] 개인 데이터셋을 통한 llama2 fine-tune

  • AIF에서 발표한 내용을 다시 작성한 Post입니다, 해당 강의는 llama2를 사용한 fine-tune의 간편화에 목적을 두었습니다.
  • 강연에서는 Gradientai, SFTTrainer 을 예시로 사용했습니다.
  • https://github.com/lxe/simple-llm-finetuner
    • 해당 github의 simple-llm-finetuner가 llama2 모델을 지원하지 않고, 추가 업데이트 없이 종료되어 해당 github의 코드를 사용하여 easy_finetuner를 만들었습니다. (https://github.com/choijhyeok/easy_finetuner?tab=readme-ov-file)
      • 해당 github의 Trainer를 SFTTrainer로 변경
      • parameter를 Colab의 T4상에서 훈련할수 있도록 변경
      • SFTTraienr parameter, T4 GPU 최적화 setting등 내부 config 정리
      • 필요없는 기능 제거, Gradio에서 데이터셋 미리보기 기능 추가
      • Huggingface의 SSH KEY 입력 추가

현재 Gradientai는 정상 작동하지 않아서 easy_finetuner를 설명하겠습니다.

 

Easy Finetuner 실행

  • 사용한 패키지의 버전은 제 github의 https://github.com/choijhyeok/easy_finetuner 에서 requirements.txt를 참고해 주세요.
  • 해당 github를 git clone하고, 양식에 맞게 데이터 저장을 우선적으로 수행해야 합니다.

 

DATASET 생성

  • 아래 코드와 같이 text 혹은 jsonl등등 text가 담겨있는 데이터를 huggingface의 Dataset형식으로 생성해 주세요.
  • 이때 중요한건 열 이름은 반드시 'text' 하나로 된 단일 열에 "instruction ~ response ~"과 같이 text 형태로 넣어주시면 됩니다.
    • 단일 열로 만든 이유는 dataset마다 열이름이 달라서 하나로 통일하고자 수행했습니다. 기본적으로 SFTTrainer를 사용하게 만들었기 때문에 Instruction Context Question Answer 혹은 Instruction Question Answer 형태로 작업해주시면 됩니다.
bergerking_dataset = []
with jsonlines.open("/content/qa_버거킹_train.jsonl") as f:
    for line in f.iter():
      # bergerking_dataset.append(f'<s>[INST] {line["inputs"]} [/INST] {line["response"]} </s>')
      bergerking_dataset.append(f'<s>### Instruction: \n{line["inputs"]} \n\n### Response: \n{line["response"]}</s>')

# 데이터셋 확인
print('데이터셋 확인')
print(bergerking_dataset[:5])

# 데이터셋 생성 및 저장
burgerking_dataset = Dataset.from_dict({"text": bergerking_dataset})
burgerking_dataset.save_to_disk('/content/easy_finetuner/example-datasets/burgerking_dataset')

# 데이터셋 info 확인
print('데이터셋 info 확인')
print(burgerking_dataset)

 

실행화면 설명

  • SFTTraienr의 parameter를 설정하여 Lora Fine-Tune을 수행

  • Easy Finetuner github의 config에 아래와 같이 설명이 추가되어 있습니다.

    parser.add_argument('--num_train_epochs', type=int, default=1, help='epochs 수')
    parser.add_argument('--learning_rate', type=float, default=2e-4, help='Learning rate')
    parser.add_argument('--output_dir', type=str, default='./result', help='저장위치')
    parser.add_argument('--per_device_train_batch_size', type=int, default=4, help='훈련당 train gpu batch size')
    ...
%cd /content/easy_finetuner
!python app.py --share

 

fine-tune llama2 RAG

  • 학습한 adapter를 base_model에 연결해서 RAG를 수행한 과정
from transformers import AutoTokenizer, LlamaForCausalLM
import torch
import argparse
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import warnings
warnings.filterwarnings('ignore')


def llama2_prompt(input_text):
  return f'### Instruction:\n{input_text}\n\n### Response:'

def llama2_output(ouput_text):
  sep = ouput_text[0]['generated_text'].split('### Response:')[1].split('### Instruction')[0].split('## Instruction')[0].split('# Instruction')[0].split('Instruction')[0]
  sep = sep[1:] if sep[0] == '.' else sep
  sep = sep[:sep.find('.')+1] if '.' in sep else sep
  return sep

adapter_name = 'beomi_llama-2-ko-7b_burgerking-ko'


compute_dtype = getattr(torch, 'float16')

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False
)

model = AutoModelForCausalLM.from_pretrained('beomi/llama-2-ko-7b', quantization_config=bnb_config, device_map={'': 0}, use_auth_token=os.environ["huggingface_token"])
model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained('beomi/llama-2-ko-7b', trust_remote_code=True, use_auth_token=os.environ["huggingface_token"])
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = PeftModel.from_pretrained(model, f'/content/easy_finetuner/lora/{adapter_name}')

 

Huggingface Pipeline

  • huggingface의 pipeline을 통해서 추론
pipe = pipeline(task="text-generation",
                model=model,
                tokenizer=tokenizer,
                max_length=150,
                do_sample=True,
                temperature=0.1,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                top_k=3,
                # top_p=0.3,
                repetition_penalty = 1.3,
                framework='pt'
                # early_stopping=True
)

prompt = "버거킹에서 판매하는 스프라이트 제로의 특징은 무엇인가요?"
result = pipe(llama2_prompt(prompt))
print(llama2_output(result))

>> 스프라이트 제로는 0kcal이며, 칼로리 걱정 없이 즐길 수 있는 음료입니다.

 

Langchain과 연결

  • 이전에 강연했던 "[AIF] 챗GPT 점메추 메뉴판, 예산입력하고 점심 메뉴 추천받자"의 코드에 적용한 내용입니다.

def bkchain_output(text):
  text = text.split('Machine:')[1] if 'Machine:' in text else text
  text = text.split('Human:')[0] if 'Human:' in text else text
  return text.strip()



system_template="""To answer the question at the end, use the following context. If you don't know the answer, just say you don't know and don't try to make up an answer.
I want you to act as my Burger King menu recommender. It tells you your budget and suggests what to buy.

please answer in korean.


{summaries}
"""
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}")
]
prompt = ChatPromptTemplate.from_messages(messages)


chain_type_kwargs = {"prompt": prompt}
bk_chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs=chain_type_kwargs,
    return_source_documents=True,
    reduce_k_below_max_tokens=True
)

result = bk_chain({"question": "스모키 바비큐 X의 독특한 특징은 무엇인가요?"})

print(f"질문 : {result['question']}")
print()
print(f"답변 : {bkchain_output(result['answer'])}")

>> 질문 : 스모키 바비큐 X의 독특한 특징은 무엇인가요?

    답변 : 스모키 바비큐 X는 강력한 불향을 입혀 더욱 깊은 풍미를 느낄 수 있습니다. 또한 육즙 가득한 쇠고    기 패티 위에 아낌없이 뿌려진 모짜렐라 치즈의 환상의 조합을 자랑합니다.

 

Gradio chat UI

  • fine-tune을 수행한 llm을 Gradio를 통한 chat 예제
import os
import logging
import sys
import gradio as gr
import torch
import gc

def reset_state():
    return [], [], "Reset Done"
def reset_textbox():
    return gr.update(value=""),""
def transfer_input(inputs):
    textbox = reset_textbox()
    return (
        inputs,
        gr.update(value=""),
        gr.Button.update(visible=True),
    )

title = """<h1 align="left" style="min-width:350px; margin-top:0;"> <img src="https://lh3.google.com/u/0/d/1txdmhh6pWjdJBpqGBRMdC0qQX2f7pzxI=w2020-h952-iv1" width="32px" style="display: inline"> AIF 버거킹 chat </h1>"""
description_top = """\
<div align="left">
<p></p>
<p>
</p >
</div>
"""

CONCURRENT_COUNT = 100


ALREADY_CONVERTED_MARK = "<!-- ALREADY CONVERTED BY PARSER. -->"

small_and_beautiful_theme = gr.themes.Soft(
        primary_hue=gr.themes.Color(
            c50="#02C160",
            c100="rgba(2, 193, 96, 0.2)",
            c200="#02C160",
            c300="rgba(2, 193, 96, 0.32)",
            c400="rgba(2, 193, 96, 0.32)",
            c500="rgba(2, 193, 96, 1.0)",
            c600="rgba(2, 193, 96, 1.0)",
            c700="rgba(2, 193, 96, 0.32)",
            c800="rgba(2, 193, 96, 0.32)",
            c900="#02C160",
            c950="#02C160",
        ),
        secondary_hue=gr.themes.Color(
            c50="#576b95",
            c100="#576b95",
            c200="#576b95",
            c300="#576b95",
            c400="#576b95",
            c500="#576b95",
            c600="#576b95",
            c700="#576b95",
            c800="#576b95",
            c900="#576b95",
            c950="#576b95",
        ),
        neutral_hue=gr.themes.Color(
            name="gray",
            c50="#f9fafb",
            c100="#f3f4f6",
            c200="#e5e7eb",
            c300="#d1d5db",
            c400="#B2B2B2",
            c500="#808080",
            c600="#636363",
            c700="#515151",
            c800="#393939",
            c900="#272727",
            c950="#171717",
        ),
        radius_size=gr.themes.sizes.radius_sm,
    ).set(
        button_primary_background_fill="#06AE56",
        button_primary_background_fill_dark="#06AE56",
        button_primary_background_fill_hover="#07C863",
        button_primary_border_color="#06AE56",
        button_primary_border_color_dark="#06AE56",
        button_primary_text_color="#FFFFFF",
        button_primary_text_color_dark="#FFFFFF",
        button_secondary_background_fill="#F2F2F2",
        button_secondary_background_fill_dark="#2B2B2B",
        button_secondary_text_color="#393939",
        button_secondary_text_color_dark="#FFFFFF",
        # background_fill_primary="#F7F7F7",
        # background_fill_primary_dark="#1F1F1F",
        block_title_text_color="*primary_500",
        block_title_background_fill="*primary_100",
        input_background_fill="#F6F6F6",
    )

with open("/content/easy_finetuner/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)


total_count = 0
def predict(input_text,
            history):
    global bk_chain

    result = bk_chain({"question": input_text})
    history = history + [((input_text, None))]
    history = history + [((None, bkchain_output(result['answer'])))]
    return history, history, "Generate: Success"


with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Success", elem_id="status_display")
    gr.Markdown(description_top)
    with gr.Row(scale=1).style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbot = gr.Chatbot(avatar_images=('https://yt3.googleusercontent.com/_JbQDtNPfI8h6RPW_9Og5qlGhSBhpMp5qX3JR7iNeSC9XZL4btbNE3dFB4ec77tauPA-nLGQTQ=s900-c-k-c0x00ffffff-no-rj', 'https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7'),elem_id="chuanhu_chatbot").style(height="100%")
            with gr.Row(scale=1):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    cancelBtn = gr.Button("Stop")
            with gr.Row(scale=1):
                emptyBtn = gr.Button(
                    "🧹 New Conversation",
                )


    predict_args = dict(
        fn=predict,
        inputs=[
            user_question,
            history
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )

    reset_args = dict(
        fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
    )

    # Chatbot
    transfer_input_args = dict(
        fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
    )



    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
    predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)

    gr.Markdown("<h2>버거킹 chat 시연 리스트</h2>")
    gr.Examples(
        examples=[
            "스모키 바비큐 X의 독특한 특징은 무엇인가요?",
            "버거킹에서 판매하는 스프라이트 제로의 특징은 무엇인가요?",
            "롱치킨버거에는 어떤 소스와 야채가 사용되었나요?"
                  ],
        inputs=user_input
    )


    emptyBtn.click(
        reset_state,
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)


demo.queue(concurrency_count=1).launch()

 

전체 과정 한번에 보기

이번 강연에서는 llama2를 활용한 fine-tune, langchain 연결을 설명했습니다.

해당 강연에서는 기본적인 내용만 빠르게 다루기 위해서 SFTTrainer에 대해서는 깊게 적용하지 않았습니다.

자세한 내용은 Youtube 강연을 확인해 주세요.

감사합니다.

반응형