Fine-tuning LLM on AMD GPU

I’ve been fascinated by open-source LLM models and have been running them locally. I like to maintain full control of the ML models I run instead of relying on the cloud, simply because it’s more fun that way. From my previous posts, you may know that I use an AMD GPU on my Arch Linux (btw), so I will continue this trend of struggling to get things working on my AMD GPU. My most recent project is an attempt to recreate the character Frieren from Frieren: Beyond Journey’s End. Why Frieren in particular? Because none of the open-source LLM models seem to be aware of this series at all, so any new behavior can be attributed to what I did rather than to the base model’s knowledge. Furthermore, Frieren has become a very successful series internationally, which makes it somewhat easier to collect a lot of materials from the internet for training purposes. I’ve also binge-read its manga, so I can validate the model’s correctness. To accomplish this, I have two milestones:

1. Fine-tune a model to mimic Frieren’s style of speech.
2. Use RAG to supplement the model with knowledge from Frieren: Beyond Journey’s End.

I’ve decided to break it down into two separate posts since they cover different techniques and many quirks. This post is focused on the first part, fine-tuning for style.

Frieren

Objective

Alright, I will be honest. What I really want to accomplish is to create a bot that mimics me so I can set it up to reply to my coworkers on Slack. Just like that episode in HBO’s Silicon Valley. To be fair, I’m sure I’m not the only one trying to do this. The recent trend of AI agents is heading in that direction. But I want to do it completely locally on my home computer. Before starting on that, I wanted to do a proof-of-concept project to see how feasible it is. So I decided to create Frieren, a fictional character unknown to open-source LLM models, along with her relevant knowledge. I’ve been following Matt Williams’s YouTube channel, which has been enormously informative in my research.

19 Tips to Better AI fine-tuning

It’s generally understood that fine-tuning is a way to train for style and focus, while RAG is the best way to bring new knowledge to an LLM. To create Frieren, I will use both techniques. So let’s get started on fine-tuning first, to mimic Frieren’s style of speech.

Hardware

Since my last post, I’ve actually made some major upgrades to my computer hardware. Most of its core parts are new.

Component   Part name
CPU         AMD Ryzen 7 9800X3D
GPU         AMD Radeon RX 7900 XTX 24GB
RAM         64GB (2 x 32GB) DDR5 6000MHz
Storage     Samsung 980 Pro 1TB

Here is a picture for good measure, since I’m also a build-a-PC enthusiast. RGB is essential, of course.

Hardware

I realized that AM4 CPU coolers are not always compatible with AM5. If your AM4 CPU cooler requires a backplate, it won’t work on AM5, because AM5 has a fixed backplate that can’t be removed. On a different note, I swear I’m not actually an AMD fanboy. I was seriously considering picking up either a used RTX 3090 or a Radeon RX 7900 XTX, both of which were about the same price. But the RX 7900 XTX has somewhat better rasterization performance, so I figured it’s a better value (for gaming, not everything else), and AMD is generally more Linux-friendly when it comes to drivers. So I went AMD again. In case you’re wondering what games I play: it’s Minecraft Bedrock Edition. Which brings us here. Why do I insist on doing ML on an AMD GPU? Why not just use Google Colab or RunPod, at least for training? Am I just a masochist? The answer is yes. Yes, I am.

Installation and Environment Setup

Getting inference working was surprisingly easy this time. I just installed ollama-rocm-git from the AUR and it worked. Great! For the UI, I chose open-webui, which was also pretty easy to install from pip. In case you need some commands for reference, here they are:

yay -S ollama-rocm-git # assuming you have yay on your Arch Linux
mkdir open-webui
cd open-webui
python3.11 -m venv venv # It's very important that you *don't* use Python 3.13, which has a backwards incompatible issue with C binary interop. Most of the libraries won't work with Python 3.13 as of writing.
source venv/bin/activate
pip install open-webui
open-webui serve

In a different terminal, run

ollama serve

And in yet another terminal:

ollama pull <pick a model from https://ollama.com/library>

Now open http://localhost:8080 and you should have inference ready to go.
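
If you’d rather sanity check inference from a script than through the UI, here’s a minimal sketch that talks to ollama’s HTTP API directly. Port 11434 is ollama’s default; the model name below is just a placeholder for whichever model you pulled.

# A minimal check that ollama is serving a model locally, bypassing open-webui.
# Assumes `ollama serve` is running on its default port 11434 and that MODEL
# matches a model you've already pulled.
import json
import urllib.request

MODEL = "llama3.2"  # placeholder; use whatever model you pulled

payload = {
    "model": MODEL,
    "prompt": "Introduce yourself in one sentence.",
    "stream": False,  # return one JSON object instead of a token stream
}
req = urllib.request.Request(
    "http://localhost:11434/api/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["response"])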

The tricky part is getting things set up for training. AMD actually has ROCm variants of most of the necessary libraries, though of varying quality. I was following AMD’s own recently published guide, Fine-tuning and inference using a single accelerator, which was enormously helpful in getting started. However, I had two open questions:

1. How do I format my training data to match the chat template of whichever base model I pick?
2. How do I merge the LoRA adapter back into the model and export a gguf file that ollama can load?

That’s when I discovered that unsloth can help with both: it can normalize the training data format for a specific model’s template, and it can merge the LoRA adapter into a gguf file that ollama can then pick up. Unfortunately, unsloth does not officially support AMD GPUs. However, according to GitHub user sayanmndl21, most of its dependencies already have ROCm variants, and they were able to get it working with a small patch to unsloth’s code. So I gave it a try. It turns out it doesn’t work for me, because the ROCm variant of xformers only supports AMD’s workstation GPUs, not consumer GPUs like my RX 7900 XTX. Hopefully this will change in the future, because as far as I can tell, xformers is the only dependency blocking unsloth from working. If I removed xformers from unsloth’s code, training would start and complete, but produce incorrect outputs. I attempted to replace it with SDP attention, but it was not a simple change and stretched beyond my domain knowledge.

I believe Axolotl does support AMD GPUs, but judging by Matt Williams’s video, it looks quite quirky to get working. Perhaps I will try it in the future. For now, I’m following AMD’s guide and borrowing just a few helper functions from unsloth, mainly dataset normalization and gguf output. Here is how I got the dependencies installed.

cd git # or wherever you keep your git repositories
git clone https://github.com/unslothai/unsloth.git
git clone --recurse https://github.com/ROCm/bitsandbytes.git
cd unsloth
python3.11 -m venv venv # Again, don't use Python 3.13
source venv/bin/activate
pip install torch --index-url https://download.pytorch.org/whl/rocm6.2.4 # As of writing I have ROCm 6.2 on my system
cd ../bitsandbytes
git checkout rocm_enabled_multi_backend
pip install -r requirements-dev.txt
cmake -DBNB_ROCM_ARCH="gfx1100" -DCOMPUTE_BACKEND=hip -S . # Find your GPU architecture from tools like amdgpu_top or rocminfo. Mine is gfx1100 = RX 7900 XTX
make # lol AMD's guide forgot this step, which means none of the binaries are actually built, resulting in some confusing errors later
python setup.py install
cd ../unsloth
pip install unsloth-zoo

If you run into errors when building bitsandbytes, make sure all the necessary HIP-related libraries are installed. Pay close attention to the error messages and install whatever libraries they are looking for using yay or pacman. If you want to run training from a different directory, just put export PYTHONPATH=~/git/unsloth/ in your environment.
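
Before moving on, it doesn’t hurt to verify that the ROCm build of PyTorch actually sees the GPU and that your bitsandbytes build imports cleanly. A minimal check, run inside unsloth’s venv:

# Quick sanity check for the ROCm toolchain. Run inside unsloth's venv.
import torch
import bitsandbytes as bnb

# ROCm builds of PyTorch expose the GPU through the usual CUDA API.
print("GPU available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
print("bitsandbytes:", bnb.__version__)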

Training Dataset Collection

We want to train for Frieren’s style of speech, so what we need to collect is her dialogue, more specifically in Q&A format, where another character speaks to Frieren and she replies directly. I found an anime transcript with a bit of Googling and re-watched the first 7 episodes of Frieren while extracting her dialogue into a plain text file, which I named frieren_dialogue.txt. This turned out to be an extremely time-consuming and labor-intensive task; I managed to collect 244 dialogues before giving up at the end of episode 7. Here is a snippet of it. Due to copyright concerns, I’m not going to share the entire raw file.

We'll have to look for work once we're back.
You're already thinking about that?
We've defeated the Demon King, but it's not over. We have the whole rest of our lives ahead of us.
I suppose you're right.
Frieren, the life ahead of you will surely be much longer than we can imagine.
Perhaps.
The king's going to erect statues of us in the plaza. I'm not sure they'll be able to faithfully recreate my handsome looks, though.
How self-serving of him. He only gave us ten copper coins when we left on our adventure.
...

This text file is pretty simple: every odd line is a user prompt and every even line is the response to the line before it. Next, we need to convert it to ShareGPT format so that unsloth can normalize it for the selected model’s template. This needs to be a JSONL file, which just means each line is one JSON object. Apparently, the reason for this instead of a single JSON array is so that a very large data file can be parsed, sampled, and split without having to load the whole file into memory. I guess that’s fair. I wrote a Python script to do this.

tosharegpt.py

#!/usr/bin/python
import argparse
import json

if __name__ == "__main__":
  parser = argparse.ArgumentParser(
    description="Script to convert alternating text dialogue to ShareGPT JSONL format"
  )
  parser.add_argument("input", help="Input text file")
  parser.add_argument("output", help="Output JSONL file")
  args = parser.parse_args()

  with open(args.input) as f:
    lines = [l.strip() for l in f]

  conversations = []
  cur = []
  for i, l in enumerate(lines):
    if i % 2 == 0:
      cur = [{"from": "human", "value": l}]
    else:
      cur.append({"from": "gpt", "value": l})
      conversations.append(cur)

  with open(args.output, "w") as f:
    for l in conversations:
      j = json.dumps({"conversations": l})
      f.write(j)
      f.write("\n")

It’s pretty straightforward to use. Run python tosharegpt.py frieren_dialogue.txt data.jsonl and you will get a JSONL file that should look like this:

{"conversations": [{"from": "human", "value": "We'll have to look for work once we're back."}, {"from": "gpt", "value": "You're already thinking about that?"}]}
{"conversations": [{"from": "human", "value": "We've defeated the Demon King, but it's not over. We have the whole rest of our lives ahead of us."}, {"from": "gpt", "value": "I suppose you're right."}]}
{"conversations": [{"from": "human", "value": "Frieren, the life ahead of you will surely be much longer than we can imagine."}, {"from": "gpt", "value": "Perhaps."}]}
{"conversations": [{"from": "human", "value": "The king's going to erect statues of us in the plaza. I'm not sure they'll be able to faithfully recreate my handsome looks, though."}, {"from": "gpt", "value": "How self-serving of him. He only gave us ten copper coins when we left on our adventure."}]}
...
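
To illustrate the point about JSONL being easy to process incrementally, here’s a small sketch that streams data.jsonl one line (one JSON object) at a time, without ever holding the whole dataset in memory:

# Stream the JSONL file line by line; each line is a self-contained JSON object.
import json

count = 0
with open("data.jsonl") as f:
    for line in f:
        convo = json.loads(line)["conversations"]
        if count < 2:  # print a couple of samples to eyeball the format
            print(convo[0]["value"], "->", convo[1]["value"])
        count += 1
print("total conversations:", count)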

LoRA fine-tuning

Let’s get training started. I picked Microsoft’s Phi-4 as the base model, simply because it’s small enough to fit in my VRAM during training but large enough to produce pretty reasonable responses most of the time. The great thing about using unsloth is that it will normalize the data template for you, so if you want to pick a different base model, you can do so quite easily. To get my training script working, I referenced AMD’s guide as well as unsloth’s Phi-4 notebook.

train.py

from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer
from unsloth import is_bfloat16_supported
from unsloth.chat_templates import get_chat_template
from unsloth.chat_templates import standardize_sharegpt
from unsloth.chat_templates import train_on_responses_only
from unsloth.save import unsloth_save_pretrained_gguf

base_model_name = "unsloth/phi-4-bnb-4bit"

# Load base model to GPU memory.
device = "cuda:0"
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, trust_remote_code = True).to(device)

# Load tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        trust_remote_code = True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Apply unsloth's Phi-4 chat template to the tokenizer.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "phi-4",
)

# Flatten each ShareGPT-style conversation into a single text string using the chat template.
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo, tokenize = False, add_generation_prompt = False
        )
        for convo in convos
    ]
    return { "text" : texts, }

# Load the ShareGPT-style JSONL dataset we created earlier.
dataset = load_dataset("json", data_files="data.jsonl", split="train")

dataset = standardize_sharegpt(dataset)
dataset = dataset.map(
    formatting_prompts_func,
    batched=True,
)

# LoRA configuration: rank-64 adapters on the attention and MLP projection layers.
peft_config = LoraConfig(
    lora_alpha = 16,
    lora_dropout = 0.1,
    r = 64,
    bias = "none",
    task_type = "CAUSAL_LM",
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)

trainer = SFTTrainer(
    model = base_model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    peft_config=peft_config,
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 4,
        max_steps = -1,
        learning_rate = 1e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3820,
        output_dir = "outputs",
        report_to = "none"
    ),
)

# Compute the loss only on Frieren's (assistant) responses, not on the user prompts.
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|im_start|>user<|im_sep|>",
    response_part="<|im_start|>assistant<|im_sep|>",
)

trainer_stats = trainer.train()

# Save the fine-tuned model in gguf format with Q4_K_M quantization for ollama.
unsloth_save_pretrained_gguf(trainer.model, "ggufmodel", tokenizer, quantization_method="q4_k_m")

Let’s walk through some key points. We start by loading unsloth’s 4-bit version of Phi-4 (you can find unsloth’s model mapping here). When I loaded the original microsoft/phi-4, it just blew up my VRAM, so I’m sticking with the 4-bit version. Next, we use unsloth’s get_chat_template and standardize_sharegpt to normalize our training data for Phi-4. Then we configure SFTTrainer for LoRA fine-tuning. Here are some of the notable parameters to watch out for. Keep in mind that my explanations are very much in layman’s terms; it’s always a good idea to look up the full documentation to learn more.

- r: the rank of the LoRA adapter matrices. A higher rank means more trainable parameters and more capacity to change the model’s behavior, at the cost of more VRAM.
- lora_alpha: a scaling factor for how strongly the adapter’s output is applied on top of the base weights.
- lora_dropout: dropout applied to the adapter during training to reduce overfitting.
- target_modules: which layers get LoRA adapters; here, the attention and MLP projection layers.
- per_device_train_batch_size and gradient_accumulation_steps: how many samples are processed per step and how many steps are accumulated before each weight update; together they determine the effective batch size.
- num_train_epochs and max_steps: how long to train. I kept this short, which deliberately underfits; more on that later.
- learning_rate: how big each weight update is. Too high and training diverges; too low and the model barely picks up the style.
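
If you want to double check that the Phi-4 chat template was applied the way you expect, you can drop something like this right after the dataset.map call in train.py (an illustrative addition, not part of the original script):

# Eyeball one formatted training example; you should see Phi-4's special
# tokens (<|im_start|>, <|im_sep|>, <|im_end|>) wrapping the dialogue.
print(dataset[0]["text"])
print("token count:", len(tokenizer(dataset[0]["text"])["input_ids"]))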

Before we start training, I chose to use unsloth’s train_on_responses_only to fine-tune only on responses instead of both prompt and response. I feel this is important, because our goal is to mimic only Frieren’s style of speech, not the other characters’. However, this function seems to cause a bug with per_device_train_batch_size, where the trainer produces this error:

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length.

What this means is that the trainer thinks entries in the dataset are not all the same length, and hence is unable to batch them; it is asking you to either pad or truncate them. That is strange, because we already enabled padding in the tokenizer. Taking a look at its source code, it appears to be doing some pretty hacky stuff that overwrites the labels of all instruction tokens with -100. Given that, I’m not sure there’s an easy way to patch this, so I will just set per_device_train_batch_size to 1 to work around the issue.
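
To make the -100 bit concrete: Hugging Face’s cross-entropy loss ignores label positions set to -100, so masking the instruction tokens that way means only the response tokens contribute to the loss. Here is a toy illustration of the idea (made-up token IDs, not unsloth’s actual code):

# Toy illustration of response-only training. The model still sees the full
# conversation in input_ids, but every prompt position in labels is -100,
# which Hugging Face's loss functions skip, so gradients only come from the
# assistant's reply. Token IDs below are made up.
prompt_ids = [101, 2054, 2003, 2115]    # hypothetical user/instruction tokens
response_ids = [3437, 2015, 2157, 102]  # hypothetical assistant tokens

input_ids = prompt_ids + response_ids
labels = [-100] * len(prompt_ids) + list(response_ids)

print(input_ids)  # full sequence fed to the model
print(labels)     # loss is computed only where labels != -100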

Finally, we can start training with trainer_stats = trainer.train(), and we borrow unsloth’s unsloth_save_pretrained_gguf to save the result in gguf format with Q4_K_M quantization. Run this script with unsloth’s venv activated.

python train.py

On my machine, this takes 244 steps and about 10 minutes to finish while consuming 16.6GB of VRAM, which is not bad at all.

Training stats

Once it finishes, you should find a ggufmodel directory has been created. Update ggufmodel/Modelfile to point at unsloth.Q4_K_M.gguf instead. Here is my Modelfile, for instance:

FROM /home/initialxy/tmp/ggufmodel/unsloth.Q4_K_M.gguf
TEMPLATE """<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|><|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|><|im_start|><|assistant|><|im_sep|>{{ .Response }}<|im_end|>"""
PARAMETER stop "<|im_end|>"
PARAMETER stop "<|im_start|>"
PARAMETER stop "<|im_sep|>"
PARAMETER temperature 1.5
PARAMETER min_p 0.1

Now we import this model into ollama. Start ollama with ollama serve in another terminal, then run:

cd ggufmodel
ollama create initialxy/frieren # or name it whatever you want

Start open-webui normally and you will see that a new model, initialxy/frieren:latest, is available. Let’s do a bit of a test.

First test

Not bad. Sounds kind of like Frieren. Again, note that I deliberately underfit this model. Try changing the amount of training with num_train_epochs or max_steps until it sounds right to you; it takes some trial and error to get it right. Next, I created a character with some custom parameters and a system prompt in open-webui.

Frieren character

I chose a temperature of 0.6 and increased the context length to 80k. Keep in mind that open-webui almost always defaults the context length to 2k, while Phi-4 is supposed to use a 16k context length. Now let’s test it again.

Second test

Not bad, I like it.

Closing

Try playing with the parameters mentioned above and see what suits you best. Also try to collect as much high-quality training data as possible, which will improve quality further. In the next post, I will discuss how to combine this model with RAG in open-webui in order to supplement it with world knowledge from Frieren: Beyond Journey’s End. It’s not as straightforward as I hoped and requires some hacky monkey patching.
