{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"align-center\">\n",
    "<a href=\"https://oumi.ai/\"><img src=\"https://oumi.ai/docs/en/latest/_static/logo/header_logo.png\" height=\"200\"></a>\n",
    "\n",
    "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n",
    "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n",
    "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n",
    "</div>\n",
    "\n",
    "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n",
    "\n",
    "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n",
    "\n",
    "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n",
    "\n",
    "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finetuning a Vision-Language Model (Overview)\n",
    "\n",
    "In this tutorial, we'll use LoRA and supervised fine-tuning (SFT) to guide a large vision-language model to produce short, concise answers grounded in visual input.\n",
    "\n",
    "Specifically, we'll use the Oumi framework to streamline the process and achieve high-quality results fast.\n",
    "\n",
    "We'll cover the following topics:\n",
    "1. Prerequisites\n",
    "2. Data Preparation & Sanity Checks\n",
    "3. Training Config Preparation\n",
    "4. Launching Training\n",
    "5. Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 📋 Prerequisites\n",
    "\n",
    "## Machine requirements\n",
    "\n",
    "This notebook requires a machine with CUDA support and a GPU with at least 24GB of memory. It therefore cannot run on the free Colab tier, which only offers a T4 GPU with 16GB of memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import torch\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"CUDA version: {torch.version.cuda}\")\n",
    "    print(f\"Number of GPUs: {torch.cuda.device_count()}\")\n",
    "    print(f\"GPU type: {torch.cuda.get_device_name()}\")\n",
    "    total_memory_gb = float(torch.cuda.mem_get_info()[1]) / float(1024 * 1024 * 1024)\n",
    "    print(f\"GPU memory: {total_memory_gb:.1f}GB\")\n",
    "    if total_memory_gb < 24.0 * 0.99:\n",
    "        print(\n",
    "            \"Error! The notebook requires at least 24GB of memory. \"\n",
    "            f\"Got: {total_memory_gb:.3f}GB\",\n",
    "            file=sys.stderr,\n",
    "        )\n",
    "    elif total_memory_gb < 30.0 * 0.99:\n",
    "        print(\n",
    "            \"You may have to reduce batch size to 1 for LoRA fine-tuning \"\n",
    "            \"to prevent CUDA OOM (out-of-memory) errors.\\n\"\n",
    "            f\"Your GPU has only {total_memory_gb:.1f}GB of VRAM.\",\n",
    "            file=sys.stderr,\n",
    "        )\n",
    "else:\n",
    "    print(\n",
    "        \"Error! The notebook will NOT run on a machine without CUDA.\", file=sys.stderr\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Oumi Installation\n",
    "\n",
    "First, let's install Oumi. You can find more detailed instructions [here](https://oumi.ai/docs/en/latest/get_started/installation.html). Here, we include Oumi's GPU dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install oumi[gpu]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Additionally, install the following packages for widget visualization.\n",
    "%pip install ipywidgets\n",
    "\n",
    "# And deactivate the parallelism warning from the tokenizers library.\n",
    "import os\n",
    "\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"  # Deactivate relevant HF warnings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure HuggingFace Access Token\n",
    "\n",
    "Llama models are gated on HuggingFace Hub. To run this notebook, you must first complete the agreement for [Llama 3.2](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) on HuggingFace, and wait for it to be accepted. Then, specify `HF_TOKEN` below to enable access to the model if it's not already set.\n",
    "\n",
    "Usually, you can get the token by running this command `cat ~/.cache/huggingface/token` on your local machine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.environ.get(\"HF_TOKEN\"):\n",
    "    # NOTE: Set your Hugging Face token here if not already set.\n",
    "    os.environ[\"HF_TOKEN\"] = \"<MY_HF_TOKEN>\"\n",
    "\n",
    "hf_token = os.environ.get(\"HF_TOKEN\")\n",
    "# Print only the last 4 characters so the token does not leak into saved outputs.\n",
    "print(f\"Using HF Token ending in: '...{hf_token[-4:]}'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating our working directory\n",
    "\n",
    "For our experiments, we'll use the following folder to save the model, training artifacts, and our inference and training configs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using the directory: '/home/user/oumi/notebooks/vision_language_tutorial'\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "tutorial_dir = Path(\"vision_language_tutorial\").resolve()\n",
    "tutorial_dir.mkdir(parents=True, exist_ok=True)\n",
    "tutorial_dir = str(tutorial_dir)  # Convert back to `str` for simplicity.\n",
    "print(f\"Using the directory: '{tutorial_dir}'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In what follows we use Meta's [Llama-3.2-11B-Vision-Instruct](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/) model.\n",
    "\n",
    "Llama-3.2-11B-Vision-Instruct is a high-performing instruction-tuned multi-modal model, which uses a moderate amount of resources (11B parameters).\n",
    "\n",
    "We will finetune this model on the [vqav2-small](https://huggingface.co/datasets/merve/vqav2-small) dataset, which will help the model respond in __a succinct manner__ to visually grounded questions.\n",
    "\n",
    "The principles presented here are generic and \"Oumi-flexible\". \n",
    "\n",
    "To repeat this experiment with other models or data, simply replace, e.g., the `model_name` (a string) with the name of another supported model (see [here](https://oumi.ai/docs/en/latest/resources/models/supported_models.html)) and adapt the configurations accordingly."
   ]
  },
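  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sanity check on the 24GB requirement, the sketch below estimates the memory needed just to hold the model weights. It assumes roughly 11 billion parameters (read off the \"11B\" in the model name, so not an exact count) stored in `bfloat16` at 2 bytes each. Activations, LoRA parameters, and optimizer state come on top of this figure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back-of-the-envelope estimate of weight memory (an approximation, not a measurement).\n",
    "approx_num_params = 11e9  # Assumption: taken from the \"11B\" in the model name\n",
    "bytes_per_param = 2  # bfloat16 uses 16 bits (2 bytes) per value\n",
    "\n",
    "weights_gib = approx_num_params * bytes_per_param / 1024**3\n",
    "print(f\"Approximate weight memory in bfloat16: {weights_gib:.1f} GiB\")"
   ]
  },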
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## First, let's initialize our dataset and build a tokenizer and an underlying data processor.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-01-28 16:02:56,348][oumi][rank0][pid:865958][MainThread][INFO]][models.py:437] Using the chat template 'llama3-instruct', which is the default for model 'meta-llama/Llama-3.2-11B-Vision-Instruct'.\n",
      "[2025-01-28 16:02:56,350][oumi][rank0][pid:865958][MainThread][INFO]][base_map_dataset.py:68] Creating map dataset (type: Vqav2SmallDataset) dataset_name: 'None', dataset_path: 'None'...\n",
      "[2025-01-28 16:02:59,913][oumi][rank0][pid:865958][MainThread][INFO]][base_map_dataset.py:472] Dataset Info:\n",
      "\tSplit: validation\n",
      "\tVersion: 0.0.0\n",
      "\tDataset size: 3391008667\n",
      "\tDownload size: 3376516283\n",
      "\tSize: 6767524950 bytes\n",
      "\tRows: 21435\n",
      "\tColumns: ['multiple_choice_answer', 'question', 'image']\n",
      "[2025-01-28 16:03:01,259][oumi][rank0][pid:865958][MainThread][INFO]][base_map_dataset.py:411] Loaded DataFrame with shape: (21435, 3). Columns:\n",
      "multiple_choice_answer    object\n",
      "question                  object\n",
      "image                     object\n",
      "dtype: object\n",
      "\n",
      "Examples included: 1000\n"
     ]
    }
   ],
   "source": [
    "from oumi.builders import build_tokenizer\n",
    "from oumi.core.configs import ModelParams\n",
    "from oumi.datasets.vision_language.vqav2_small import Vqav2SmallDataset\n",
    "\n",
    "model_name = \"meta-llama/Llama-3.2-11B-Vision-Instruct\"\n",
    "\n",
    "tokenizer = build_tokenizer(ModelParams(model_name=model_name))\n",
    "\n",
    "dataset = Vqav2SmallDataset(\n",
    "    tokenizer=tokenizer,\n",
    "    processor_name=model_name,\n",
    "    limit=1000,  # Limit the number of examples for demonstration purposes (!)\n",
    ")\n",
    "\n",
    "print(\"\\nExamples included:\", len(dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now let's see a few examples to get a feel for the dataset we are going to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "from oumi.core.types.conversation import Type\n",
    "\n",
    "num_examples_to_display = 3\n",
    "\n",
    "for i in range(num_examples_to_display):\n",
    "    conversation = dataset.conversation(i)  # Retrieve the i-th example (conversation)\n",
    "\n",
    "    print(f\"Example {i}:\")\n",
    "\n",
    "    for message in conversation.messages:\n",
    "        if message.role == \"user\":  # The `user` poses a question, regarding an image\n",
    "            img_content = message.image_content_items[-1]\n",
    "            assert img_content.binary is not None\n",
    "            image = Image.open(io.BytesIO(img_content.binary))\n",
    "            image.save(f\"{tutorial_dir}/example_{i}.png\")  # Save the image locally\n",
    "            display(image)\n",
    "\n",
    "        print(f\"{message.role}: {message.content}\")\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see above, the ground-truth answers are **very short and succinct**, which is an advantage in scenarios where we want the model to generate concise answers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>multiple_choice_answer</th>\n",
       "      <th>question</th>\n",
       "      <th>image</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>carnival ride</td>\n",
       "      <td>Where are the kids riding?</td>\n",
       "      <td>{'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>yes</td>\n",
       "      <td>Is this boy a good pitcher?</td>\n",
       "      <td>{'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>wetsuit</td>\n",
       "      <td>What is the person wearing?</td>\n",
       "      <td>{'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>How many sinks are in this bathroom?</td>\n",
       "      <td>{'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>soccer</td>\n",
       "      <td>What sport are the girls playing?</td>\n",
       "      <td>{'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  multiple_choice_answer                              question  \\\n",
       "0          carnival ride            Where are the kids riding?   \n",
       "1                    yes           Is this boy a good pitcher?   \n",
       "2                wetsuit           What is the person wearing?   \n",
       "3                      4  How many sinks are in this bathroom?   \n",
       "4                 soccer     What sport are the girls playing?   \n",
       "\n",
       "                                               image  \n",
       "0  {'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...  \n",
       "1  {'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...  \n",
       "2  {'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...  \n",
       "3  {'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...  \n",
       "4  {'bytes': b'\\xff\\xd8\\xff\\xe0\\x00\\x10JFIF\\x00\\x...  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# You can also inspect the underlying data directly. It is stored in a\n",
    "# pandas DataFrame, which you can preview with the following command:\n",
    "dataset.data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initial Model Responses\n",
    "\n",
    "Let's now see how this model performs on a given prompt without any finetuning.\n",
    "- For this, we will create and execute an `inference configuration` stored in a YAML file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile \"{tutorial_dir}/infer.yaml\"\n",
    "\n",
    "model:\n",
    "  model_name: \"meta-llama/Llama-3.2-11B-Vision-Instruct\"\n",
    "  torch_dtype_str: \"bfloat16\" # Good choice if you have access to an Ampere or newer GPU\n",
    "  chat_template: \"llama3-instruct\"\n",
    "  model_max_length: 1024\n",
    "  trust_remote_code: False # For other models this might need to be set to True\n",
    "  \n",
    "generation:\n",
    "  max_new_tokens: 128\n",
    "  batch_size: 1\n",
    "  \n",
    "engine: NATIVE \n",
    "# Let's use the `native` engine (i.e., the underlying machine's default)\n",
    "# for inference.  \n",
    "# You can also consider VLLM for much faster inference if you are working with a GPU. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-01-28 16:03:04,292][oumi][rank0][pid:865958][MainThread][INFO]][models.py:185] Building model using device_map: auto (DeviceRankInfo(world_size=1, rank=0, local_world_size=1, local_rank=0))...\n",
      "[2025-01-28 16:03:04,292][oumi][rank0][pid:865958][MainThread][INFO]][models.py:255] Using model class: <class 'transformers.models.auto.modeling_auto.AutoModelForVision2Seq'> to instantiate model.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).\n",
      "WARNING:accelerate.utils.modeling:The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3529d0fc54844a3791a01803b08e52bd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-01-28 16:03:09,567][oumi][rank0][pid:865958][MainThread][INFO]][models.py:428] Using the chat template 'llama3-instruct' specified in model config!\n",
      "[2025-01-28 16:03:11,495][oumi][rank0][pid:865958][MainThread][INFO]][native_text_inference_engine.py:111] Setting EOS token id to `128009`\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[USER: <IMAGE_BINARY> | Is this boy a good pitcher?\n",
       " ASSISTANT: The boy in the image is wearing a baseball uniform and appears to be pitching, but it's difficult to determine if he's a good pitcher based on this image alone.]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from oumi.core.configs import InferenceConfig\n",
    "from oumi.core.types.conversation import Conversation, Message, Role\n",
    "from oumi.inference import NativeTextInferenceEngine\n",
    "\n",
    "# Note: the *first* inference call will take a few minutes to download\n",
    "# and cache the model (assuming you do not already have it downloaded locally).\n",
    "inference_config = InferenceConfig.from_yaml(str(Path(tutorial_dir) / \"infer.yaml\"))\n",
    "inference_engine = NativeTextInferenceEngine(inference_config.model)\n",
    "\n",
    "example = dataset.conversation(1)\n",
    "example = Conversation(messages=example.filter_messages(role=Role.USER))\n",
    "inference_engine.infer([example], inference_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up to free up the GPU memory used for inference above\n",
    "import gc\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def cleanup_memory():\n",
    "    \"\"\"Delete the inference_engine and collect garbage.\"\"\"\n",
    "    global inference_engine\n",
    "    if inference_engine:\n",
    "        del inference_engine\n",
    "        inference_engine = None\n",
    "    for _ in range(3):\n",
    "        gc.collect()\n",
    "        torch.cuda.empty_cache()\n",
    "        torch.cuda.synchronize()\n",
    "\n",
    "\n",
    "cleanup_memory()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Is this boy a good pitcher?\n",
      "\n",
      "@@@@@@@@@@@@@@@@@@@\n",
      "@                 @\n",
      "@   @@@@@  @  @   @\n",
      "@   @   @  @  @   @\n",
      "@   @@@@@  @@@@   @\n",
      "@                 @\n",
      "@   @@@@@@@   @   @\n",
      "@   @  @  @   @   @\n",
      "@   @  @  @   @   @\n",
      "@                 @\n",
      "@@@@@@@@@@@@@@@@@@@\n",
      "\n",
      "[2025-01-28 16:39:46,835][oumi][rank0][pid:876355][MainThread][INFO]][models.py:185] Building model using device_map: auto (DeviceRankInfo(world_size=1, rank=0, local_world_size=1, local_rank=0))...\n",
      "[2025-01-28 16:39:47,027][oumi][rank0][pid:876355][MainThread][INFO]][models.py:255] Using model class: <class 'transformers.models.auto.modeling_auto.AutoModelForVision2Seq'> to instantiate model.\n",
      "INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).\n",
      "WARNING:accelerate.utils.modeling:The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n",
      "Loading checkpoint shards: 100%|██████████████████| 5/5 [00:06<00:00,  1.34s/it]\n",
      "[2025-01-28 16:39:54,531][oumi][rank0][pid:876355][MainThread][INFO]][models.py:428] Using the chat template 'llama3-instruct' specified in model config!\n",
      "Enter your input prompt: [2025-01-28 16:39:56,469][oumi][rank0][pid:876355][MainThread][INFO]][native_text_inference_engine.py:111] Setting EOS token id to `128009`\n",
      "------------\n",
      "USER: <IMAGE_BINARY> | Is this boy a good pitcher?\n",
      "ASSISTANT: The boy in the image is wearing a baseball uniform and appears to be pitching, but it's difficult to determine if he's a good pitcher based on this image alone.\n",
      "------------\n",
      "\n",
      "Enter your input prompt: \n",
      "Exiting...\n"
     ]
    }
   ],
   "source": [
    "# Note: you can run the same inference directly from our CLI (terminal) instead of\n",
    "# the Python API. The last line of this cell pipes the question into `oumi infer`:\n",
    "\n",
    "conversation_id = 1\n",
    "query = dataset.conversation(conversation_id).messages[0].text_content_items[0].content\n",
    "print(f\"\\n{query}\")\n",
    "image_file = f\"{tutorial_dir}/example_{conversation_id}.png\"\n",
    "\n",
    "!echo \"{query}\" | oumi infer -c \"{tutorial_dir}/infer.yaml\" -i --image=\"{image_file}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OK! As you can see, by default this model gives quite __verbose__ responses. Can we change this behavior?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing our training experiment\n",
    " - Specifically, let's create and execute a YAML file with our _training_ config!\n",
    " - You can find many more details about the listed hyper-parameters in our [docs](https://oumi.ai/docs/en/latest/user_guides/train/training_methods.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting /home/user/oumi/notebooks/vision_language_tutorial/train.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile \"{tutorial_dir}/train.yaml\"\n",
    "\n",
    "model:\n",
    "  model_name: \"meta-llama/Llama-3.2-11B-Vision-Instruct\"\n",
    "  torch_dtype_str: \"bfloat16\"\n",
    "  model_max_length: 1024  \n",
    "  attn_implementation: \"sdpa\"\n",
    "  chat_template: \"llama3-instruct\"\n",
    "  freeze_layers:\n",
    "    - \"vision_model\"     # Let's finetune only the language component of the model\n",
    "\n",
    "data:\n",
    "  train:\n",
    "    collator_name: \"vision_language_with_padding\" # Simple padding collator\n",
    "    use_torchdata: True\n",
    "\n",
    "    datasets:\n",
    "      - dataset_name: \"merve/vqav2-small\"\n",
    "        split: \"validation\" # This dataset has only a validation split\n",
    "        shuffle: True\n",
    "        seed: 42\n",
    "        transform_num_workers: \"auto\"\n",
    "        dataset_kwargs:\n",
    "          # The default for our model:\n",
    "          processor_name: \"meta-llama/Llama-3.2-11B-Vision-Instruct\"           \n",
    "          limit: 1000 # Again, we downsample to 1000 examples for demonstration \n",
    "                      # purposes only.\n",
    "          return_tensors: True      \n",
    "\n",
    "training:\n",
    "  output_dir: \"vision_language_tutorial\"\n",
    "  trainer_type: \"TRL_SFT\"\n",
    "  enable_gradient_checkpointing: True\n",
    "  # You can decrease the following two params if you run out of memory\n",
    "  per_device_train_batch_size: 2 # Use batch size 1 if you only have 24GB of GPU VRAM.\n",
    "  gradient_accumulation_steps: 8 # Thus effective batch size is 2x8=16 on a single GPU\n",
    "  use_peft: True\n",
    "  \n",
    "  # **NOTE**\n",
    "  # We set `max_steps` to 40 steps to first verify that training works.\n",
    "  # Swap to `num_train_epochs: 1` to get more meaningful results\n",
    "  # (one training epoch will take ~25 mins on a single A100-40GB GPU).\n",
    "  max_steps: 40\n",
    "  # num_train_epochs: 1\n",
    "\n",
    "  gradient_checkpointing_kwargs:\n",
    "    # Reentrant docs: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint\n",
    "    use_reentrant: False\n",
    "  ddp_find_unused_parameters: False\n",
    "  empty_device_cache_steps: 1\n",
    "\n",
    "  optimizer: \"adamw_torch_fused\"\n",
    "  learning_rate: 2e-5\n",
    "  warmup_ratio: 0.03\n",
    "  weight_decay: 0.0\n",
    "  lr_scheduler_type: \"cosine\"\n",
    "\n",
    "  logging_steps: 5\n",
    "  save_steps: 0\n",
    "  dataloader_main_process_only: False\n",
    "  dataloader_num_workers: \"auto\"\n",
    "  dataloader_prefetch_factor: 16  \n",
    "  enable_wandb: True # Set to False if you don't want to use Weights & Biases\n",
    "  \n",
    "peft: # Our LoRA configuration; we target several layers  \n",
    "  lora_r: 8\n",
    "  lora_alpha: 8\n",
    "  lora_dropout: 0.1\n",
    "  lora_target_modules:\n",
    "    - \"q_proj\"\n",
    "    - \"v_proj\"\n",
    "    - \"o_proj\"\n",
    "    - \"k_proj\"\n",
    "    - \"gate_proj\"\n",
    "    - \"up_proj\"\n",
    "    - \"down_proj\"\n",
    "  lora_init_weights: GAUSSIAN\n",
    "\n",
    "# The lines below only take effect if you have access to multiple GPUs.\n",
    "# If you do, please uncomment them to train with all available GPUs:\n",
    "\n",
    "# fsdp:\n",
    "#   enable_fsdp: True\n",
    "#   sharding_strategy: \"HYBRID_SHARD\"\n",
    "#   forward_prefetch: True\n",
    "#   auto_wrap_policy: \"TRANSFORMER_BASED_WRAP\"\n",
    "#   transformer_layer_cls: \"MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer\""
   ]
  },
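  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the batch-size comments in the config above concrete, here is a small sketch of the arithmetic. The values are copied from `train.yaml`; `num_gpus = 1` is an assumption (a single-GPU run), and the steps-per-epoch figure assumes the final partial batch still counts as one optimizer step, which may differ from what the trainer actually does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Values copied from train.yaml above.\n",
    "num_examples = 1000\n",
    "per_device_train_batch_size = 2\n",
    "gradient_accumulation_steps = 8\n",
    "max_steps = 40\n",
    "num_gpus = 1  # Assumption: a single-GPU run\n",
    "\n",
    "effective_batch_size = (\n",
    "    per_device_train_batch_size * gradient_accumulation_steps * num_gpus\n",
    ")\n",
    "steps_per_epoch = math.ceil(num_examples / effective_batch_size)\n",
    "examples_seen = max_steps * effective_batch_size\n",
    "\n",
    "print(f\"Effective batch size: {effective_batch_size}\")\n",
    "print(f\"Optimizer steps per epoch (approx.): {steps_per_epoch}\")\n",
    "print(\n",
    "    f\"Examples seen in {max_steps} steps: {examples_seen} \"\n",
    "    f\"({examples_seen / num_examples:.0%} of the data)\"\n",
    ")"
   ]
  },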
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-claim memory!\n",
    "cleanup_memory()"
   ]
  },
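  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before launching, it may help to see why LoRA keeps training cheap. The sketch below takes `lora_r` and `lora_alpha` from the config above, but the layer dimensions (`d_in`, `d_out`) are hypothetical placeholders chosen for illustration, not values read from the model. It relies on the standard LoRA formulation, in which a frozen `d_out x d_in` weight gets two trainable low-rank factors with `lora_r * (d_in + d_out)` parameters in total, and the update is scaled by `lora_alpha / lora_r`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_r = 8  # From the `peft` section of train.yaml\n",
    "lora_alpha = 8  # From the `peft` section of train.yaml\n",
    "d_in, d_out = 4096, 4096  # Hypothetical layer size, for illustration only\n",
    "\n",
    "frozen_params = d_in * d_out  # Weights of the original (frozen) linear layer\n",
    "lora_params = lora_r * (d_in + d_out)  # Trainable low-rank factors A and B\n",
    "scaling = lora_alpha / lora_r  # Multiplier applied to the low-rank update\n",
    "\n",
    "print(f\"Frozen parameters in the layer: {frozen_params:,}\")\n",
    "print(f\"Trainable LoRA parameters: {lora_params:,} ({lora_params / frozen_params:.2%})\")\n",
    "print(f\"LoRA scaling factor: {scaling}\")"
   ]
  },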
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Let's launch the training!\n",
    "\n",
    "!oumi train -c \"{tutorial_dir}/train.yaml\"\n",
    "\n",
    "# Or, if you have multiple GPUs you want to use:\n",
    "# !oumi distributed torchrun -m oumi train -c \"{tutorial_dir}/train.yaml\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finally, let's use the Fine-tuned Model and see the effect of training!\n",
    "\n",
    "Once we're happy with the results, we can serve the fine-tuned model for inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting /home/user/oumi/notebooks/vision_language_tutorial/trained_infer.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile \"{tutorial_dir}/trained_infer.yaml\"\n",
    "\n",
    "model:\n",
    "  model_name: \"meta-llama/Llama-3.2-11B-Vision-Instruct\"  \n",
    "  adapter_model: \"vision_language_tutorial\" # Directory with our saved LoRA parameters!\n",
    "  torch_dtype_str: \"bfloat16\"\n",
    "  chat_template: \"llama3-instruct\"\n",
    "  model_max_length: 1024\n",
    "  trust_remote_code: False\n",
    "\n",
    "generation:\n",
    "  max_new_tokens: 256\n",
    "  batch_size: 1\n",
    "  \n",
    "engine: NATIVE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "cleanup_memory()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-01-28 16:21:51,146][oumi][rank0][pid:865958][MainThread][INFO]][models.py:185] Building model using device_map: auto (DeviceRankInfo(world_size=1, rank=0, local_world_size=1, local_rank=0))...\n",
      "[2025-01-28 16:21:51,147][oumi][rank0][pid:865958][MainThread][INFO]][models.py:255] Using model class: <class 'transformers.models.auto.modeling_auto.AutoModelForVision2Seq'> to instantiate model.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).\n",
      "WARNING:accelerate.utils.modeling:The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37d03be7b36c4421be06e0e39848e96a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-01-28 16:21:56,531][oumi][rank0][pid:865958][MainThread][INFO]][models.py:236] Loading PEFT adapter from: vision_language_tutorial ...\n",
      "[2025-01-28 16:21:57,627][oumi][rank0][pid:865958][MainThread][INFO]][models.py:428] Using the chat template 'llama3-instruct' specified in model config!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[USER: <IMAGE_BINARY> | Is this boy a good pitcher?\n",
       " ASSISTANT: yes]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config = InferenceConfig.from_yaml(str(Path(tutorial_dir) / \"trained_infer.yaml\"))\n",
    "inference_engine = NativeTextInferenceEngine(config.model)\n",
    "\n",
    "example = dataset.conversation(1)\n",
    "example = Conversation(messages=example.filter_messages(role=Role.USER))\n",
    "# Pass the fine-tuned `config` (not the earlier `inference_config`) to the engine.\n",
    "inference_engine.infer([example], config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-01-28 16:22:01,541][oumi][rank0][pid:865958][MainThread][INFO]][native_text_inference_engine.py:111] Setting EOS token id to `128009`\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[USER: <IMAGE_BINARY> | Is this boy a good pitcher?\n",
       " ASSISTANT: yes]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Or, if you want to test it with your own image/question pair:\n",
    "from oumi.core.types.conversation import ContentItem\n",
    "from oumi.utils.image_utils import load_image_png_bytes_from_path\n",
    "\n",
    "your_image_path = f\"{tutorial_dir}/example_1.png\"  # Replace with your image path!\n",
    "image_bytes = load_image_png_bytes_from_path(your_image_path)\n",
    "\n",
    "conversation = Conversation(\n",
    "    messages=[\n",
    "        Message(\n",
    "            role=Role.USER,\n",
    "            content=[\n",
    "                ContentItem(type=Type.IMAGE_BINARY, binary=image_bytes),\n",
    "                # Replace the question below with your own question!\n",
    "                ContentItem(type=Type.TEXT, content=\"Is this boy a good pitcher?\"),\n",
    "            ],\n",
    "        )\n",
    "    ]\n",
    ")\n",
    "\n",
    "inference_engine.infer([conversation], config)"
   ]
  },
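  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put a number on the change in verbosity, the sketch below compares the word counts of the two answers shown in the outputs above. The strings are copied from those outputs. A whitespace split is only a crude proxy for length, and a single example is an illustration rather than an evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Answers copied from the notebook outputs above (before and after fine-tuning).\n",
    "base_answer = (\n",
    "    \"The boy in the image is wearing a baseball uniform and appears to be \"\n",
    "    \"pitching, but it's difficult to determine if he's a good pitcher based \"\n",
    "    \"on this image alone.\"\n",
    ")\n",
    "finetuned_answer = \"yes\"\n",
    "\n",
    "base_words = len(base_answer.split())\n",
    "finetuned_words = len(finetuned_answer.split())\n",
    "print(f\"Base model answer: {base_words} words\")\n",
    "print(f\"Fine-tuned model answer: {finetuned_words} word(s)\")"
   ]
  },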
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🧭 What's Next?\n",
    "\n",
    "Congrats on finishing this notebook! Feel free to check out our other [notebooks](https://github.com/oumi-ai/oumi/tree/main/notebooks) in the [Oumi GitHub](https://github.com/oumi-ai/oumi), and give us a star! You can also join the Oumi community over on [Discord](https://discord.gg/oumi).\n",
    "\n",
    "📰 Want to keep up with news from Oumi? Subscribe to our [Substack](https://blog.oumi.ai/) and [YouTube](https://www.youtube.com/@Oumi_AI)!\n",
    "\n",
    "⚡ Interested in building custom AI in hours, not months? Apply to get [early access](https://oumi.ai/contact?utm_source=oumi_oss_tutorial_vlms) to the Oumi Platform, or [chat with us](https://oumi.ai/book?utm_source=oumi_oss_tutorial_vlms) to learn more!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
