{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"align-center\">\n",
    "<a href=\"https://oumi.ai/\"><img src=\"https://oumi.ai/docs/en/latest/_static/logo/header_logo.png\" height=\"200\"></a>\n",
    "\n",
    "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n",
    "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n",
    "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/oumi-ai/oumi/blob/main/notebooks/Oumi - Training CNN on Custom Dataset.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "</div>\n",
    "\n",
    "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n",
    "\n",
    "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n",
    "\n",
    "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n",
    "\n",
    "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S7bYaH10SgtN"
   },
   "source": [
    "# Training CNN on Custom Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QkhNGqE1SgtP"
   },
   "source": [
    "Oumi is not limited to LLMs. This example shows how to train a simple ConvNet classifier on a custom dataset stored as binary data in a NumPy `.npz` file. The dataset is created from the classic MNIST dataset (handwritten digit classification)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fHDr11SqSgtP"
   },
   "source": [
    "## 📋 Prerequisites\n",
    "\n",
    "❗**NOTICE:** We recommend running this notebook on a GPU. If running on Google Colab, you can use the free T4 GPU runtime (Colab Menu: `Runtime` -> `Change runtime type`).\n",
    "\n",
    "First, let's install Oumi together with its GPU dependencies. For more detailed instructions, see the [installation guide](https://oumi.ai/docs/en/latest/get_started/installation.html).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install oumi[gpu]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fHDr11SqSgtP"
   },
   "source": [
    "## Environment Setup: Common Imports and Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torchvision\n",
    "\n",
    "tutorial_dir = \"cnn_mnist_tutorial\"\n",
    "\n",
    "Path(tutorial_dir).mkdir(parents=True, exist_ok=True)\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"  # Disable warnings from HF"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fHDr11SqSgtP"
   },
   "source": [
    "# Data\n",
    "## Data Preparation\n",
    "First, let's convert the MNIST dataset to an `.npz` archive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "JPmWKRVCSgtP"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved 70000 examples to '/home/user/oumi/notebooks/cnn_mnist_tutorial/mnist.npz'!\n"
     ]
    }
   ],
   "source": [
    "images = []\n",
    "labels = []\n",
    "splits = []\n",
    "for train_split in (False, True):\n",
    "    mnist_dataset = torchvision.datasets.MNIST(\n",
    "        root=Path(\"/tmp/mnist_data\"),\n",
    "        train=train_split,\n",
    "        download=True,\n",
    "    )\n",
    "    num_examples = len(mnist_dataset)\n",
    "    images.extend(\n",
    "        [np.asarray(mnist_dataset.data[i], dtype=np.uint8) for i in range(num_examples)]\n",
    "    )\n",
    "    labels.extend([int(mnist_dataset.targets[i]) for i in range(num_examples)])\n",
    "    splits.extend([(\"train\" if train_split else \"test\")] * num_examples)\n",
    "\n",
    "npz_filename = (Path(tutorial_dir) / \"mnist.npz\").absolute()\n",
    "\n",
    "# Normalize to [0, 1] and convert [N,H,W] to [N,C,H,W] by adding a dummy C=1 axis (PyTorch convention).\n",
    "images = np.expand_dims((np.stack(images).astype(dtype=np.float32) / 255.0), axis=1)\n",
    "np.savez_compressed(\n",
    "    npz_filename, images=images, labels=np.stack(labels), split=np.stack(splits)\n",
    ")\n",
    "print(f\"Saved {len(labels)} examples to '{npz_filename}'!\")"
   ]
  },
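  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before moving on, it can help to confirm what was written to disk. The cell below is a minimal sketch that reopens the archive with `np.load` and prints each array's shape and dtype, plus the number of examples per split. It assumes the previous cell has run and that the archive lives at `cnn_mnist_tutorial/mnist.npz` (the path used in this notebook). MNIST ships with 60,000 training and 10,000 test images, so those are the counts to expect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Assumption: the archive was saved by the previous cell at this relative path.\n",
    "archive_path = Path(\"cnn_mnist_tutorial\") / \"mnist.npz\"\n",
    "\n",
    "with np.load(archive_path) as archive:\n",
    "    # List every array stored in the archive with its shape and dtype.\n",
    "    for name in sorted(archive.files):\n",
    "        array = archive[name]\n",
    "        print(f\"{name}: shape={array.shape}, dtype={array.dtype}\")\n",
    "\n",
    "    # Count how many examples belong to each split.\n",
    "    split_names, split_counts = np.unique(archive[\"split\"], return_counts=True)\n",
    "    print(dict(zip(split_names.tolist(), split_counts.tolist())))"
   ]
  },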
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define a custom Oumi dataset that can load the MNIST data from the `.npz` archive. For more details, refer to the [datasets documentation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from typing_extensions import override\n",
    "\n",
    "from oumi.core.datasets import BaseMapDataset\n",
    "from oumi.core.registry import register_dataset\n",
    "\n",
    "\n",
    "@register_dataset(\"npz_file\")\n",
    "class NpzDataset(BaseMapDataset):\n",
    "    \"\"\"Loads dataset from Numpy .npz archive.\"\"\"\n",
    "\n",
    "    default_dataset = \"custom\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        *,\n",
    "        dataset_name: str | None = None,\n",
    "        dataset_path: str | Path | None = None,\n",
    "        split: str | None = None,\n",
    "        npz_split_col: str | None = None,\n",
    "        npz_allow_pickle: bool = False,\n",
    "        **kwargs,\n",
    "    ) -> None:\n",
    "        \"\"\"Initializes a new instance of the NpzDataset class.\n",
    "\n",
    "        Args:\n",
    "            dataset_name: Dataset name.\n",
    "            dataset_path: Path to .npz file.\n",
    "            split: Dataset split.\n",
    "            npz_split_col: Name of '.npz' array containing dataset split info.\n",
    "                If unspecified, then the name \"split\" is assumed by default.\n",
    "            npz_allow_pickle: Whether pickle is allowed when loading data\n",
    "                from the npz archive.\n",
    "            **kwargs: Additional arguments to pass to the parent class.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If dataset_path is not provided, or\n",
    "                if .npz file contains data in unexpected format.\n",
    "        \"\"\"\n",
    "        if not dataset_path:\n",
    "            raise ValueError(\"`dataset_path` must be provided\")\n",
    "        super().__init__(\n",
    "            dataset_name=dataset_name,\n",
    "            dataset_path=(str(dataset_path) if dataset_path is not None else None),\n",
    "            split=split,\n",
    "            **kwargs,\n",
    "        )\n",
    "        self._npz_allow_pickle = npz_allow_pickle\n",
    "        self._npz_split_col = npz_split_col\n",
    "\n",
    "        dataset_path = Path(dataset_path)\n",
    "        if not dataset_path.is_file():\n",
    "            raise ValueError(f\"Path is not a file! '{dataset_path}'\")\n",
    "        elif dataset_path.suffix.lower() != \".npz\":\n",
    "            raise ValueError(f\"File extension is not '.npz'! '{dataset_path}'\")\n",
    "\n",
    "        self._data = self._load_data()\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_list(x: np.ndarray) -> list:\n",
    "        # `pd.DataFrame` expects Python lists for columns\n",
    "        # (elements can still be `ndarray`)\n",
    "        if len(x.shape) > 1:\n",
    "            return [x[i, ...] for i in range(x.shape[0])]\n",
    "        return x.tolist()\n",
    "\n",
    "    @override\n",
    "    def _load_data(self) -> pd.DataFrame:\n",
    "        data_dict = {}\n",
    "        if not self.dataset_path:\n",
    "            raise ValueError(\"dataset_path is empty!\")\n",
    "        with np.load(self.dataset_path, allow_pickle=self._npz_allow_pickle) as npzfile:\n",
    "            feature_names = list(sorted(npzfile.files))\n",
    "            if len(feature_names) == 0:\n",
    "                raise ValueError(\n",
    "                    f\"'.npz' archive contains no data! '{self.dataset_path}'\"\n",
    "                )\n",
    "            num_examples = None\n",
    "            for feature_name in feature_names:\n",
    "                col_data = npzfile[feature_name]\n",
    "                assert isinstance(col_data, np.ndarray)\n",
    "                if num_examples is None:\n",
    "                    num_examples = col_data.shape[0]\n",
    "                elif num_examples != col_data.shape[0]:\n",
    "                    raise ValueError(\n",
    "                        \"Inconsistent number of examples for features \"\n",
    "                        f\"'{feature_name}' and '{feature_names[0]}': \"\n",
    "                        f\"{col_data.shape[0]} vs {num_examples}!\"\n",
    "                    )\n",
    "                data_dict[feature_name] = self._to_list(col_data)\n",
    "\n",
    "        dataframe: pd.DataFrame = pd.DataFrame(data_dict)\n",
    "\n",
    "        split_feature_name = (self._npz_split_col or \"split\") if self.split else None\n",
    "        if split_feature_name:\n",
    "            if split_feature_name not in dataframe:\n",
    "                raise ValueError(\n",
    "                    f\"'.npz' doesn't contain data split info: '{split_feature_name}'!\"\n",
    "                )\n",
    "            dataframe = pd.DataFrame(\n",
    "                dataframe[dataframe[split_feature_name] == self.split].drop(\n",
    "                    split_feature_name, axis=1\n",
    "                ),\n",
    "                copy=True,\n",
    "            )\n",
    "        return dataframe\n",
    "\n",
    "    @override\n",
    "    def transform(self, sample: pd.Series) -> dict:\n",
    "        \"\"\"Preprocesses the inputs in the given sample.\"\"\"\n",
    "        return sample.to_dict()"
   ]
  },
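  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the dataset class in action, the sketch below loads the `test` split and inspects one sample. It assumes the previous cells have run (so `NpzDataset` is defined and the archive exists), and that `BaseMapDataset` supplies `len()` and integer indexing that routes each row through `transform`. That is our understanding of the base class, not something defined in this notebook. Since the split column is dropped after filtering, each sample should contain only the `images` and `labels` entries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Assumption: the archive was saved earlier at this relative path.\n",
    "archive_path = Path(\"cnn_mnist_tutorial\") / \"mnist.npz\"\n",
    "\n",
    "# `NpzDataset` is the class defined in the previous cell.\n",
    "test_dataset = NpzDataset(dataset_path=archive_path, split=\"test\")\n",
    "\n",
    "# Assumption: `len()` and integer indexing are provided by `BaseMapDataset`.\n",
    "sample = test_dataset[0]\n",
    "print(f\"Number of test examples: {len(test_dataset)}\")\n",
    "print(f\"Sample keys: {sorted(sample.keys())}\")\n",
    "print(f\"Image shape: {sample['images'].shape}, label: {sample['labels']}\")"
   ]
  },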
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "m8PTJuc4SgtQ"
   },
   "source": [
    "# Training a Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "b2_HamuySgtQ"
   },
   "source": [
    "Oumi provides the sample `CnnClassifier` model [[source](https://github.com/oumi-ai/oumi/blob/main/src/oumi/models/cnn_classifier.py)]. Let's use it to train a classifier for MNIST handwritten digits.\n",
    "\n",
    "Oumi uses [training configuration files](https://oumi.ai/docs/en/latest/api/oumi.core.configs.html#oumi.core.configs.TrainingConfig) to specify training parameters. The next cell writes a training config for `CnnClassifier` to disk. Let's give it a try!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "l2SQ9fZiSgtQ"
   },
   "outputs": [],
   "source": [
    "yaml_content = f\"\"\"\n",
    "model:\n",
    "  model_name: \"CnnClassifier\"\n",
    "  torch_dtype_str: \"float32\"\n",
    "  load_pretrained_weights: False\n",
    "  model_kwargs:\n",
    "      image_width: 28   # MNIST images are 28x28 single channel\n",
    "      image_height: 28\n",
    "      in_channels: 1\n",
    "      output_dim: 10    # Number of output classes: 10 digits\n",
    "\n",
    "data:\n",
    "  train:\n",
    "    use_torchdata: True\n",
    "    datasets:\n",
    "      - dataset_name: \"npz_file\" # Custom dataset defined above for .npz archives\n",
    "        dataset_path: \"{npz_filename}\"\n",
    "        split: \"train\"\n",
    "\n",
    "training:\n",
    "  trainer_type: \"OUMI\"  # For non-transformers, use \"OUMI\" trainer\n",
    "  per_device_train_batch_size: 64\n",
    "  max_steps: 2000\n",
    "  logging_steps: 500\n",
    "  run_name: \"mnist_cnn_classifier\"\n",
    "  output_dir: \"{tutorial_dir}/output\"\n",
    "\"\"\"\n",
    "\n",
    "with open(f\"{tutorial_dir}/train.yaml\", \"w\") as f:\n",
    "    f.write(yaml_content)"
   ]
  },
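  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the config is built with an f-string, it is worth checking that the placeholders were substituted as intended. The sketch below parses the file with PyYAML's `yaml.safe_load` and prints a few fields. It assumes the previous cell has run and that PyYAML is available in the environment (it is not imported elsewhere in this notebook)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import yaml  # Assumption: PyYAML is installed in this environment.\n",
    "\n",
    "# Assumption: the config was written by the previous cell at this relative path.\n",
    "config_path = Path(\"cnn_mnist_tutorial\") / \"train.yaml\"\n",
    "\n",
    "with open(config_path) as f:\n",
    "    parsed_config = yaml.safe_load(f)\n",
    "\n",
    "# The dataset path should be the absolute path of the saved archive.\n",
    "train_dataset_entry = parsed_config[\"data\"][\"train\"][\"datasets\"][0]\n",
    "print(f\"Dataset: {train_dataset_entry['dataset_name']}\")\n",
    "print(f\"Dataset path: {train_dataset_entry['dataset_path']}\")\n",
    "print(f\"Output dir: {parsed_config['training']['output_dir']}\")\n",
    "print(f\"Max steps: {parsed_config['training']['max_steps']}\")"
   ]
  },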
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2GpQDGG5SgtQ"
   },
   "outputs": [],
   "source": [
    "from oumi.core.configs import TrainingConfig\n",
    "from oumi.train import train\n",
    "\n",
    "config = TrainingConfig.from_yaml(str(Path(tutorial_dir) / \"train.yaml\"))\n",
    "\n",
    "train(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🧭 What's Next?\n",
    "\n",
    "Congratulations, you've trained your first CNN using data from a custom dataset (`numpy` arrays) with Oumi! Feel free to check out our other [notebooks](https://github.com/oumi-ai/oumi/tree/main/notebooks) in the [Oumi GitHub](https://github.com/oumi-ai/oumi), and give us a star! You can also join the Oumi community over on [Discord](https://discord.gg/oumi).\n",
    "\n",
    "📰 Want to keep up with news from Oumi? Subscribe to our [Substack](https://blog.oumi.ai/) and [YouTube](https://www.youtube.com/@Oumi_AI)!\n",
    "\n",
    "⚡ Interested in building custom AI in hours, not months? Apply to get [early access](https://oumi.ai/contact?utm_source=oumi_oss_tutorial_cnn_custom_dataset) to the Oumi Platform, or [chat with us](https://oumi.ai/book?utm_source=oumi_oss_tutorial_cnn_custom_dataset) to learn more!"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "oumi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
