#!/bin/bash
# KODI 9B cassidy training - paste-and-run on a RunPod pod (>=48GB GPU, PyTorch template).
#
# Steps: install ai-toolkit + deps, fetch the klein 9B base weights and the
# cassidy dataset zip from the funnel host, write a LoRA job config, train,
# then list the resulting .safetensors under /workspace/output/cassidy_9b/.
#
# Optional env: HF_TOKEN - exported to the training run (empty if unset).
#
# Save this to a file and run it with `bash <file>`: the strict-mode options
# below would terminate an interactive shell on the first failing command.
#
# pipefail matters here: several steps pipe into `tail`, and without it a
# failing `pip install ... | tail -1` would report tail's success and the
# script would carry on with missing dependencies.
set -euo pipefail
FUNNEL="https://mini161.tailfc34e4.ts.net"
cd /workspace
echo "=== [1/5] deps ==="
# Quiet pip install: show only the last line of pip's output, but return
# pip's own exit status. A bare `pip ... | tail -1` takes tail's status unless
# pipefail is enabled, which would hide install failures.
quiet_pip() {
  pip install -q "$@" 2>&1 | tail -n 1
  return "${PIPESTATUS[0]}"
}
quiet_pip huggingface_hub oyaml torchdiffeq safetensors
[[ -d ai-toolkit ]] || git clone -q https://github.com/ostris/ai-toolkit
# cd on its own so a failure aborts the script (under set -e) instead of
# being swallowed by the best-effort submodule handling below.
cd ai-toolkit
# Submodule init is best-effort: warn rather than stop if it fails.
if ! git submodule update --init --recursive -q 2>/dev/null; then
  echo "warning: git submodule update failed; continuing" >&2
fi
quiet_pip -r requirements.txt
echo "=== [2/5] 9B base + dataset (from funnel) ==="
BASE_MODEL=/workspace/models/flux-2-klein-base-9b.safetensors
# Must match datasets[0].folder_path in the training config.
DATASET_DIR=/workspace/train_data/8_cassidy
mkdir -p /workspace/models
if [[ ! -s "$BASE_MODEL" ]]; then
  # Download to a temporary name and rename only after curl succeeds, so an
  # interrupted transfer cannot leave a truncated file that passes the -s
  # (non-empty) check and gets reused on the next run.
  curl -fL --retry 3 "$FUNNEL/klein-9b-base.safetensors" -o "$BASE_MODEL.part"
  mv -f -- "$BASE_MODEL.part" "$BASE_MODEL"
fi
curl -fL --retry 3 "$FUNNEL/cassidy_train.zip" -o /workspace/cassidy_train.zip
rm -rf /workspace/train_data && mkdir -p /workspace/train_data
unzip -qo /workspace/cassidy_train.zip -d /workspace/train_data
# Fail now, before training starts, if the zip's layout does not provide
# the folder the config points at, or if it holds no images.
if [[ ! -d "$DATASET_DIR" ]]; then
  echo "error: $DATASET_DIR not found after unzip (the training config expects it)" >&2
  exit 1
fi
img_count=$(find "$DATASET_DIR" -type f \( -iname '*.png' -o -iname '*.jpg' \
  -o -iname '*.jpeg' -o -iname '*.webp' \) | wc -l)
echo "imgs: $img_count"
if (( img_count == 0 )); then
  echo "error: no images found in $DATASET_DIR" >&2
  exit 1
fi
printf '%s\n' "=== [3/5] config (9B, dim32, 2500 steps, NO quant - relies on big VRAM) ==="
# ai-toolkit job config: LoRA rank/alpha 32, 2500 steps at lr 1e-4, bf16 with
# quantization off, checkpoints every 500 steps (at most 6 kept) under
# training_folder /workspace/output. sample_every (99999) is larger than the
# step count, presumably so periodic sampling never fires - confirm against
# ai-toolkit's sampling schedule.
# The heredoc delimiter is quoted so the YAML is written literally: no
# $-expansion or command substitution can alter it.
CONFIG_FILE=/workspace/ai-toolkit/config/cassidy_9b.yaml
mkdir -p /workspace/ai-toolkit/config /workspace/output
cat > "$CONFIG_FILE" <<'YAML'
job: extension
config:
  name: cassidy_9b
  process:
    - type: sd_trainer
      training_folder: /workspace/output
      device: cuda:0
      network:
        type: lora
        linear: 32
        linear_alpha: 32
      save:
        dtype: float16
        save_every: 500
        max_step_saves_to_keep: 6
      datasets:
        - folder_path: /workspace/train_data/8_cassidy
          caption_ext: txt
          cache_latents_to_disk: true
          resolution: [768, 1024]
      train:
        batch_size: 1
        steps: 2500
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false
        gradient_checkpointing: true
        noise_scheduler: flowmatch
        optimizer: adamw8bit
        lr: 1e-4
        dtype: bf16
        quantize: false
        cache_text_embeddings: true
        unload_text_encoder: true
      model:
        name_or_path: /workspace/models
        arch: flux2_klein_9b
        flux2_klein_te_path: Qwen/Qwen3-8B
        flux2_te_filename: flux-2-klein-base-9b.safetensors
        quantize: false
      sample:
        sampler: flowmatch
        sample_every: 99999
        prompts: ["cassidy, photo of a woman"]
        neg: ""
        seed: 7777
        steps: 20
meta:
  name: cassidy_9b
YAML
printf '%s\n' "=== [4/5] TRAIN (will pull Qwen3-8B from HF first; ~2-3h) ==="
cd /workspace/ai-toolkit
# Hand any caller-supplied Hugging Face token to the trainer. The := default
# turns an unset HF_TOKEN into an exported empty string, which keeps this
# line safe under `set -u`.
: "${HF_TOKEN:=}"
export HF_TOKEN
python3 run.py config/cassidy_9b.yaml
printf '%s\n' "=== [5/5] DONE ==="
OUT_DIR=/workspace/output/cassidy_9b
printf '%s\n' "LORA(S):"
# Lists every checkpoint written. If none exist, ls exits non-zero and
# set -e stops the script here.
ls -la "$OUT_DIR"/*.safetensors
printf '%s\n' ">>> Download $OUT_DIR/cassidy_9b.safetensors via the RunPod file browser <<<"
