{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"align-center\">\n",
    "<a href=\"https://oumi.ai/\"><img src=\"https://oumi.ai/docs/en/latest/_static/logo/header_logo.png\" height=\"200\"></a>\n",
    "\n",
    "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n",
    "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n",
    "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n",
    "</div>\n",
    "\n",
    "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n",
    "\n",
    "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n",
    "\n",
    "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n",
    "\n",
    "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S7bYaH10SgtN"
   },
   "source": [
    "# OpenEnv GRPO with trl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QkhNGqE1SgtP"
   },
   "source": [
    "In this tutorial notebook, we use Oumi to train an agentic model on an [OpenEnv](https://github.com/meta-pytorch/OpenEnv) Echo reinforcement learning (RL) environment with the GRPO algorithm. To achieve this, we use Hugging Face's trl library with a custom rollout function that interacts with the vLLM inference server and the OpenEnv environment. This notebook is derived from trl's [Echo environment example](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py)."
   ]
  },
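  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "GRPO scores each completion relative to the other completions sampled for the same prompt: within a group, rewards are normalized to zero mean and unit variance, and the normalized values serve as advantages. The cell below is a minimal, self-contained sketch of that normalization (illustrative only; trl computes this internally during training):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def group_relative_advantages(rewards, eps=1e-4):\n",
    "    # One group = all completions sampled for the same prompt.\n",
    "    mean = sum(rewards) / len(rewards)\n",
    "    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)\n",
    "    # Completions better than the group mean get positive advantages.\n",
    "    return [(r - mean) / (var**0.5 + eps) for r in rewards]\n",
    "\n",
    "\n",
    "# Example: four hypothetical environment rewards for one prompt's completions.\n",
    "print(group_relative_advantages([0.1, 0.4, 0.4, 0.9]))"
   ]
  },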
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fHDr11SqSgtP"
   },
   "source": [
    "# 📋 Prerequisites\n",
    "\n",
    "❗**NOTICE:** This notebook must be run on a machine with at least two GPUs.\n",
    "\n",
    "## Oumi Installation\n",
    "\n",
    "First, let's install the latest versions of Oumi and OpenEnv. You can find more detailed instructions [here](https://oumi.ai/docs/en/latest/get_started/installation.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install uv && uv pip install \"oumi[gpu] @ git+https://github.com/oumi-ai/oumi.git\"\n",
    "!uv pip install git+https://github.com/meta-pytorch/OpenEnv.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "JPmWKRVCSgtP"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "tutorial_dir = \"openenv_tutorial\"\n",
    "\n",
    "Path(tutorial_dir).mkdir(parents=True, exist_ok=True)\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"  # Disable warnings from HF."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start OpenEnv and vLLM servers\n",
    "\n",
    "We need to run two servers in addition to the trl trainer. The OpenEnv server receives actions from the LLM and returns the updated state and reward. The vLLM server handles inference, and its model weights are kept in sync with the trainer's updated weights over the course of training. We start both in separate subprocesses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting openenv_tutorial/start_openenv_server.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile $tutorial_dir/start_openenv_server.py\n",
    "\n",
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "import threading\n",
    "import time\n",
    "from pathlib import Path\n",
    "\n",
    "import requests\n",
    "\n",
    "\n",
    "def stream_output(pipe, prefix=\"\"):\n",
    "    \"\"\"Stream output lines from subprocess pipe to stdout.\"\"\"\n",
    "    for line in iter(pipe.readline, \"\"):\n",
    "        print(f\"{prefix}{line}\", end=\"\")\n",
    "    pipe.close()\n",
    "\n",
    "\n",
    "print(\"⚡ Starting FastAPI server for Echo Environment...\")\n",
    "\n",
    "work_dir = str(Path.cwd().parent.absolute())\n",
    "\n",
    "server_process = subprocess.Popen(\n",
    "    [\n",
    "        sys.executable,\n",
    "        \"-m\",\n",
    "        \"uvicorn\",\n",
    "        \"envs.echo_env.server.app:app\",\n",
    "        \"--host\",\n",
    "        \"0.0.0.0\",\n",
    "        \"--port\",\n",
    "        \"8001\",\n",
    "    ],\n",
    "    env={**os.environ, \"PYTHONPATH\": f\"{work_dir}/src\"},\n",
    "    stdout=subprocess.PIPE,\n",
    "    stderr=subprocess.PIPE,\n",
    "    text=True,\n",
    "    cwd=work_dir,\n",
    ")\n",
    "\n",
    "# Start a background thread to stream stderr\n",
    "threading.Thread(\n",
    "    target=stream_output, args=(server_process.stderr, \"🔥 [stderr] \"), daemon=True\n",
    ").start()\n",
    "\n",
    "print(\"⏳ Waiting for server to start...\")\n",
    "time.sleep(5)\n",
    "\n",
    "try:\n",
    "    response = requests.get(\"http://0.0.0.0:8001/health\", timeout=2)\n",
    "    response.raise_for_status()\n",
    "    print(\"\\n✅ Echo Environment server is running!\")\n",
    "except Exception as e:\n",
    "    print(f\"\\n❌ Server failed to start: {e}\")\n",
    "    print(\"\\n📋 Checking error output...\")\n",
    "    server_process.poll()\n",
    "    if server_process.stderr:\n",
    "        stderr = server_process.stderr.read()\n",
    "        if stderr:\n",
    "            print(stderr)\n",
    "    raise\n",
    "\n",
    "try:\n",
    "    input(\"Press Enter to exit...\\n\")\n",
    "finally:\n",
    "    print(\"🛑 Stopping server...\")\n",
    "    server_process.terminate()\n",
    "    server_process.wait()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Servers started. PIDs: 3787594 3787595\n"
     ]
    }
   ],
   "source": [
    "import subprocess\n",
    "\n",
    "# Start both servers in the background\n",
    "server1 = subprocess.Popen(\n",
    "    [\n",
    "        \"bash\",\n",
    "        \"-c\",\n",
    "        (\n",
    "            \"CUDA_VISIBLE_DEVICES=0 trl vllm-serve \"\n",
    "            \"--model Qwen/Qwen2.5-0.5B-Instruct \"\n",
    "            \"--log-level warning \"\n",
    "            \"--host 0.0.0.0 --port 8000\"\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "server2 = subprocess.Popen([\"python\", f\"{tutorial_dir}/start_openenv_server.py\"])\n",
    "\n",
    "print(\"Servers started. PIDs:\", server1.pid, server2.pid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "⚡ Starting FastAPI server for Echo Environment...\n",
      "⏳ Waiting for server to start...\n",
      "❌ Server not ready: HTTPConnectionPool(host='0.0.0.0', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7d30b0f93950>: Failed to establish a new connection: [Errno 111] Connection refused'))\n",
      "🔥 [stderr] INFO:     Started server process [3787596]\n",
      "🔥 [stderr] INFO:     Waiting for application startup.\n",
      "🔥 [stderr] INFO:     Application startup complete.\n",
      "🔥 [stderr] INFO:     Uvicorn running on http://0.0.0.0:8001 (Press CTRL+C to quit)\n",
      "\n",
      "✅ Echo Environment server is running!\n",
      "Press Enter to exit...\n",
      "❌ Server not ready: HTTPConnectionPool(host='0.0.0.0', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7d30b0f9df50>: Failed to establish a new connection: [Errno 111] Connection refused'))\n",
      "❌ Server not ready: HTTPConnectionPool(host='0.0.0.0', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7d30b0fa4250>: Failed to establish a new connection: [Errno 111] Connection refused'))\n",
      "INFO 10-31 17:31:32 [__init__.py:216] Automatically detected platform cuda.\n",
      "INFO 10-31 17:31:33 [utils.py:328] non-default args: {'disable_log_stats': True, 'worker_extension_cls': 'trl.scripts.vllm_serve.WeightSyncWorkerExtension', 'model_impl': 'vllm', 'model': 'Qwen/Qwen2.5-0.5B-Instruct'}\n",
      "❌ Server not ready: HTTPConnectionPool(host='0.0.0.0', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7d30b0fa6810>: Failed to establish a new connection: [Errno 111] Connection refused'))\n",
      "INFO 10-31 17:31:41 [__init__.py:742] Resolved architecture: Qwen2ForCausalLM\n",
      "INFO 10-31 17:31:41 [__init__.py:1815] Using max model len 32768\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`torch_dtype` is deprecated! Use `dtype` instead!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 10-31 17:31:42 [scheduler.py:222] Chunked prefill is enabled with max_num_batched_tokens=16384.\n",
      "❌ Server not ready: HTTPConnectionPool(host='0.0.0.0', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7d30b0fad190>: Failed to establish a new connection: [Errno 111] Connection refused'))\n",
      "❌ Server not ready: HTTPConnectionPool(host='0.0.0.0', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7d30b0fa44d0>: Failed to establish a new connection: [Errno 111] Connection refused'))\n",
      "INFO 10-31 17:31:51 [__init__.py:216] Automatically detected platform cuda.\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:52 [core.py:654] Waiting for init message from front-end.\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:52 [core.py:76] Initializing a V1 LLM engine (v0.10.2) with config: model='Qwen/Qwen2.5-0.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-0.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen2.5-0.5B-Instruct, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, 
compilation_config={\"level\":3,\"debug_dump_path\":\"\",\"cache_dir\":\"\",\"backend\":\"\",\"custom_ops\":[],\"splitting_ops\":[\"vllm.unified_attention\",\"vllm.unified_attention_with_output\",\"vllm.mamba_mixer2\",\"vllm.mamba_mixer\",\"vllm.short_conv\",\"vllm.linear_attention\",\"vllm.plamo2_mamba_mixer\",\"vllm.gdn_attention\"],\"use_inductor\":true,\"compile_sizes\":[],\"inductor_compile_config\":{\"enable_auto_functionalized_v2\":false},\"inductor_passes\":{},\"cudagraph_mode\":1,\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":1,\"cudagraph_capture_sizes\":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"cudagraph_copy_inputs\":false,\"full_cuda_graph\":false,\"pass_config\":{},\"max_capture_size\":512,\"local_cache_dir\":null}\n",
      "❌ Server not ready: HTTPConnectionPool(host='0.0.0.0', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7d30b0fa5090>: Failed to establish a new connection: [Errno 111] Connection refused'))\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:52 [worker_base.py:595] Injected <class 'trl.scripts.vllm_serve.WeightSyncWorkerExtension'> into <class 'vllm.v1.worker.gpu_worker.Worker'> for extended collective_rpc calls ['close_communicator', 'init_communicator', 'update_named_param']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[W1031 17:31:53.898867831 ProcessGroupNCCL.cpp:981] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0\n",
      "[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0\n",
      "[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0\n",
      "[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0\n",
      "[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0\n",
      "[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:53 [parallel_state.py:1165] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m WARNING 10-31 17:31:53 [topk_topp_sampler.py:69] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:53 [gpu_model_runner.py:2338] Starting to load model Qwen/Qwen2.5-0.5B-Instruct...\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:53 [gpu_model_runner.py:2370] Loading model from scratch...\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:53 [cuda.py:362] Using Flash Attention backend on V1 engine.\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:53 [weight_utils.py:348] Using model weights format ['*.safetensors']\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:53 [weight_utils.py:406] No model.safetensors.index.json found in remote.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.39it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.39it/s]\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:53 [default_loader.py:268] Loading weights took 0.18 seconds\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:54 [gpu_model_runner.py:2392] Model loading took 0.9266 GiB and 0.487034 seconds\n",
      "❌ Server not ready: HTTPConnectionPool(host='0.0.0.0', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7d30b0f9d690>: Failed to establish a new connection: [Errno 111] Connection refused'))\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:58 [backends.py:539] Using cache directory: /home/wizeng/.cache/vllm/torch_compile_cache/5d31f4c583/rank_0_0/backbone for vLLM's torch.compile\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:58 [backends.py:550] Dynamo bytecode transform time: 3.53 s\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:59 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 1.361 s\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:31:59 [monitor.py:34] torch.compile takes 3.53 s in total\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:00 [gpu_worker.py:298] Available KV cache memory: 64.71 GiB\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:00 [kv_cache_utils.py:864] GPU KV cache size: 5,654,128 tokens\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:00 [kv_cache_utils.py:868] Maximum concurrency for 32,768 tokens per request: 172.55x\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):  81%|████████  | 54/67 [00:01<00:00, 38.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "❌ Server not ready: HTTPConnectionPool(host='0.0.0.0', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7d30b0faefd0>: Failed to establish a new connection: [Errno 111] Connection refused'))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 67/67 [00:01<00:00, 40.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:02 [gpu_model_runner.py:3118] Graph capturing finished in 2 secs, took 0.50 GiB\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:02 [gpu_worker.py:391] Free memory on device (78.59/79.19 GiB) on startup. Desired GPU memory utilization is (0.9, 71.27 GiB). Actual usage is 0.93 GiB for weight, 5.57 GiB for peak activation, 0.07 GiB for non-torch memory, and 0.5 GiB for CUDAGraph memory. Replace gpu_memory_utilization config with `--kv-cache-memory=68779733708` to fit into requested memory, or `--kv-cache-memory=76635612672` to fully utilize gpu memory. Current kv cache memory in use is 69478085324 bytes.\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:02 [core.py:218] init engine (profile, create kv cache, warmup model) took 8.68 seconds\n",
      "INFO 10-31 17:32:04 [llm.py:295] Supported_tasks: ['generate']\n",
      "INFO 10-31 17:32:04 [__init__.py:36] No IOProcessor plugins requested by the model\n",
      "✅ vLLM server is healthy!\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "import requests\n",
    "\n",
    "URL = \"http://0.0.0.0:8000/health\"\n",
    "\n",
    "\n",
    "def check_vllm_health():\n",
    "    \"\"\"Checks if the vLLM server is healthy.\"\"\"\n",
    "    try:\n",
    "        response = requests.get(URL, timeout=3)\n",
    "        if response.status_code == 200:\n",
    "            print(\"✅ vLLM server is healthy!\")\n",
    "            return True\n",
    "        else:\n",
    "            print(f\"⚠️ Server responded with {response.status_code}\")\n",
    "    except requests.RequestException as e:\n",
    "        print(f\"❌ Server not ready: {e}\")\n",
    "    return False\n",
    "\n",
    "\n",
    "max_retries = 24\n",
    "for attempt in range(1, max_retries + 1):\n",
    "    if check_vllm_health():\n",
    "        break\n",
    "    time.sleep(5)\n",
    "else:\n",
    "    print(f\"❌ Failed to start vLLM server after {max_retries} attempts.\")"
   ]
  },
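  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, we can optionally sanity-check the Echo environment with a single reset/step round trip. The Echo environment echoes the message back and returns a reward. The helper below is a hypothetical convenience (not part of Oumi or OpenEnv) that returns the reward for one step, or `None` if the environment is unreachable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def echo_env_smoke_test(base_url):\n",
    "    # Send one message to the Echo environment and return its reward.\n",
    "    try:\n",
    "        from envs.echo_env import EchoEnv\n",
    "        from envs.echo_env.models import EchoAction\n",
    "\n",
    "        client = EchoEnv(base_url=base_url)\n",
    "        client.reset()\n",
    "        result = client.step(EchoAction(message='hello, echo!'))\n",
    "        return result.reward\n",
    "    except Exception as exc:\n",
    "        print(f'Echo environment not reachable: {exc}')\n",
    "        return None\n",
    "\n",
    "\n",
    "print(echo_env_smoke_test('http://0.0.0.0:8001'))"
   ]
  },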
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model!\n",
    "\n",
    "By providing a custom rollout function to interact with the OpenEnv and vLLM servers, we can use trl to do agentic GRPO training. We also need to provide a reward function that processes the reward value output by the environment.\n",
    "\n",
    "The following script defines the custom rollout and reward functions and runs the trainer. We run it as a subprocess so that we can set `CUDA_VISIBLE_DEVICES` to a GPU that doesn't conflict with the vLLM server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting openenv_tutorial/train.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile $tutorial_dir/train.py\n",
    "\n",
    "import requests\n",
    "from envs.echo_env import EchoEnv\n",
    "from envs.echo_env.models import EchoAction\n",
    "\n",
    "from oumi.core.configs import TrainingConfig\n",
    "from oumi.core.registry import RegistryType, register\n",
    "from oumi.train import train\n",
    "\n",
    "\n",
    "@register(\"env_reward\", RegistryType.REWARD_FUNCTION)\n",
    "def reward_from_env(completions, **kwargs):\n",
    "    \"\"\"Reward function that uses the environment reward.\"\"\"\n",
    "    # Extract environment rewards from kwargs (propagated via extra_fields)\n",
    "    env_rewards = kwargs.get(\"env_reward\", [])\n",
    "    if env_rewards:\n",
    "        return [float(reward) for reward in env_rewards]\n",
    "    else:\n",
    "        # Fallback if env_reward is not available\n",
    "        return [0.0] * len(completions)\n",
    "\n",
    "\n",
    "@register(\"echo_env_vllm_rollout\", RegistryType.ROLLOUT_FUNCTION)\n",
    "def echo_env_vllm_rollout(\n",
    "    prompts: list[str], args, processing_class\n",
    ") -> dict[str, list]:\n",
    "    \"\"\"Custom rollout function that generates completions via vLLM server and computes environment rewards.\n",
    "\n",
    "    Args:\n",
    "        prompts: List of prompts to generate from\n",
    "        args: GRPOConfig containing all sampling parameters\n",
    "        processing_class: Tokenizer/processor for decoding completions\n",
    "\n",
    "    Returns:\n",
    "        Dict containing prompt_ids, completion_ids, logprobs, and env_reward\n",
    "    \"\"\"  # noqa: E501\n",
    "    # 1. Generate completions via vLLM inference server (running on port 8000)\n",
    "    payload = {\n",
    "        \"prompts\": prompts,\n",
    "        \"n\": args.num_generations,\n",
    "        \"temperature\": args.temperature,\n",
    "        \"top_p\": args.top_p,\n",
    "        \"top_k\": -1 if args.top_k is None else args.top_k,\n",
    "        \"min_p\": 0.0 if args.min_p is None else args.min_p,\n",
    "        \"max_tokens\": args.max_completion_length,\n",
    "        \"repetition_penalty\": args.repetition_penalty,\n",
    "    }\n",
    "    response = requests.post(\"http://0.0.0.0:8000/generate/\", json=payload)\n",
    "\n",
    "    if response.status_code != 200:\n",
    "        print(f\"Error response: {response.text}\")\n",
    "\n",
    "    response.raise_for_status()\n",
    "    result = response.json()\n",
    "\n",
    "    completions_text = processing_class.batch_decode(\n",
    "        result[\"completion_ids\"], skip_special_tokens=True\n",
    "    )\n",
    "\n",
    "    # 2. Step through the environment to get rewards\n",
    "    client = EchoEnv(base_url=\"http://0.0.0.0:8001\")\n",
    "    env_result = client.reset()\n",
    "    env_rewards = []\n",
    "    for msg in completions_text:\n",
    "        env_result = client.step(EchoAction(message=msg))\n",
    "        env_rewards.append(env_result.reward)\n",
    "\n",
    "    # 3. Add environment rewards as extra field\n",
    "    result[\"env_reward\"] = env_rewards\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "config = TrainingConfig.from_yaml(\"openenv_tutorial/grpo_train.yaml\")\n",
    "train(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we define the YAML training config and kick off training!\n",
    "\n",
    "To enable logging to Weights and Biases, uncomment the relevant line in the config below, and make sure to [set up wandb](https://oumi.ai/docs/en/latest/development/dev_setup.html#optional-set-up-weights-and-biases) on your machine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting openenv_tutorial/grpo_train.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile $tutorial_dir/grpo_train.yaml\n",
    "\n",
    "model:\n",
    "  model_name: \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
    "  model_max_length: 2048\n",
    "  torch_dtype_str: \"bfloat16\"\n",
    "  attn_implementation: \"sdpa\"\n",
    "\n",
    "data:\n",
    "  train:\n",
    "    datasets:\n",
    "      - dataset_name: \"trl-lib/ultrafeedback-prompt\"\n",
    "        split: \"train\"\n",
    "        sample_count: 100\n",
    "\n",
    "training:\n",
    "  trainer_type: \"TRL_GRPO\"\n",
    "  per_device_train_batch_size: 8\n",
    "  gradient_accumulation_steps: 4\n",
    "\n",
    "  reward_functions: [\"env_reward\"]\n",
    "\n",
    "  ddp_find_unused_parameters: False\n",
    "  optimizer: \"adamw_torch_fused\"\n",
    "\n",
    "  grpo:\n",
    "    use_vllm: True\n",
    "    rollout_function: \"echo_env_vllm_rollout\"\n",
    "    max_completion_length: 2048\n",
    "\n",
    "  dataloader_num_workers: \"auto\"\n",
    "  dataloader_prefetch_factor: 32\n",
    "\n",
    "  num_train_epochs: 1\n",
    "  logging_steps: 1\n",
    "  log_model_summary: False\n",
    "  output_dir: \"openenv_tutorial/echo_grpo\"\n",
    "  # Uncomment to enable wandb logging\n",
    "  # enable_wandb: True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-10-31 17:32:17,665][oumi][rank0][pid:3788270][MainThread][WARNING]][training_config.py:149] Ignored model.model_max_length=2048 parameter for trainer TrainerType.TRL_GRPO.\n",
      "[2025-10-31 17:32:17,666][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:117] Creating training.output_dir: openenv_tutorial/echo_grpo...\n",
      "[2025-10-31 17:32:17,668][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:119] Created training.output_dir absolute path: /home/wizeng/repos/oumi/notebooks/openenv_tutorial/echo_grpo\n",
      "[2025-10-31 17:32:17,669][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:117] Creating training.telemetry_dir: openenv_tutorial/echo_grpo/telemetry...\n",
      "[2025-10-31 17:32:17,672][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:119] Created training.telemetry_dir absolute path: /home/wizeng/repos/oumi/notebooks/openenv_tutorial/echo_grpo/telemetry\n",
      "[2025-10-31 17:32:17,675][oumi][rank0][pid:3788270][MainThread][INFO]][torch_utils.py:80] Torch version: 2.8.0+cu128. NumPy version: 1.26.4\n",
      "[2025-10-31 17:32:17,675][oumi][rank0][pid:3788270][MainThread][INFO]][torch_utils.py:88] CUDA version: 12.8 \n",
      "[2025-10-31 17:32:17,676][oumi][rank0][pid:3788270][MainThread][INFO]][torch_utils.py:91] CuDNN version: 90.8.0\n",
      "[2025-10-31 17:32:17,825][oumi][rank0][pid:3788270][MainThread][INFO]][torch_utils.py:124] CPU cores: 208 CUDA devices: 1\n",
      "device(0)='NVIDIA H100 80GB HBM3' Capability: (9, 0) Memory: [Total: 79.19GiB Free: 78.68GiB Allocated: 0.0GiB Cached: 0.0GiB]\n",
      "[2025-10-31 17:32:17,831][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:154] Oumi version: 0.4.3.dev7+gee26267d5.d20251030\n",
      "[2025-10-31 17:32:17,838][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:156] Git revision hash: 568abe14e6e819ef5e844360cb054bc34b5911eb\n",
      "[2025-10-31 17:32:17,881][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:157] Git tag: None\n",
      "[2025-10-31 17:32:17,885][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:165] Resolved 'training.dataloader_num_workers=auto' to 'training.dataloader_num_workers=2'\n",
      "[2025-10-31 17:32:17,911][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:313] Training config saved to openenv_tutorial/echo_grpo/telemetry/training_config.yaml\n",
      "[2025-10-31 17:32:18,276][oumi][rank0][pid:3788270][MainThread][INFO]][models.py:544] Using the model's built-in chat template for model 'Qwen/Qwen2-0.5B-Instruct'.\n",
      "[2025-10-31 17:32:18,835][oumi][rank0][pid:3788270][MainThread][INFO]][models.py:260] Building model using device_map: auto (DeviceRankInfo(world_size=1, rank=0, local_world_size=1, local_rank=0))...\n",
      "[2025-10-31 17:32:18,861][oumi][rank0][pid:3788270][MainThread][INFO]][models.py:336] Using model class: <class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'> to instantiate model.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`torch_dtype` is deprecated! Use `dtype` instead!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-10-31 17:32:19,171][oumi][rank0][pid:3788270][MainThread][INFO]][torch_utils.py:288] \n",
      "Model Parameters Summary:\n",
      "🔢 Total     parameters: 494,032,768\n",
      "🔗 Embedding parameters: 136,134,656\n",
      "🎯 Trainable parameters: 494,032,768\n",
      "🔒 Frozen    parameters: 0 (0.00%)\n",
      "\n",
      "INFO 10-31 17:32:19 [__init__.py:216] Automatically detected platform cuda.\n",
      "[2025-10-31 17:32:19,942][oumi][rank0][pid:3788270][MainThread][INFO]][torch_profiler_utils.py:164] PROF: Torch Profiler disabled!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/wizeng/repos/oumi/src/oumi/builders/training.py:70: UserWarning: You are importing from 'rollout_func', which is an experimental feature. This API may change or be removed at any time without prior notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.\n",
      "  trainer = HuggingFaceTrainer(cls(*args, **kwargs, args=hf_args), processor)\n",
      "The model is already on multiple devices. Skipping the move to device specified in `args`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: The cache directory for DeepSpeed Triton autotune, /home/wizeng/.triton/autotune, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path.\n",
      "INFO 10-31 17:32:20 [__init__.py:1433] Found nccl from library libnccl.so.2\n",
      "INFO 10-31 17:32:20 [pynccl.py:70] vLLM is using nccl==2.27.3\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:20 [__init__.py:1433] Found nccl from library libnccl.so.2\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:20 [pynccl.py:70] vLLM is using nccl==2.27.3\n",
      "[2025-10-31 17:32:21,330][oumi][rank0][pid:3788270][MainThread][INFO]][device_utils.py:343] GPU Metrics Before Training: GPU runtime info: NVidiaGpuRuntimeInfo(device_index=0, device_count=2, used_memory_mb=75593.0, temperature=33, fan_speed=None, fan_speeds=None, power_usage_watts=123.946, power_limit_watts=700.0, gpu_utilization=0, memory_utilization=0, performance_state=0, clock_speed_graphics=1980, clock_speed_sm=1980, clock_speed_memory=2619).\n",
      "[2025-10-31 17:32:21,330][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:553] Training init time: 3.665s\n",
      "[2025-10-31 17:32:21,330][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:554] Starting training... (TrainerType.TRL_GRPO, transformers: 4.57.1)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.\n",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:22 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1391.38it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:01<00:00, 27.21it/s, est. speed input: 1054.58 toks/s, output: 3325.28 toks/s]\n",
      "  4%|▍         | 1/25 [00:02<01:02,  2.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.3104, 'grad_norm': 5.6875, 'learning_rate': 5e-05, 'num_tokens': 5150.0, 'completions/mean_length': 122.1875, 'completions/min_length': 22.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.15625, 'completions/mean_terminated_length': 97.40740966796875, 'completions/min_terminated_length': 22.0, 'completions/max_terminated_length': 244.0, 'rewards/reward_from_env/mean': 60.24374771118164, 'rewards/reward_from_env/std': 46.80415344238281, 'reward': 60.24374771118164, 'reward_std': 20.179214477539062, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.10425987094640732, 'sampling/sampling_logp_difference/max': 1.4170303344726562, 'sampling/importance_sampling_ratio/min': 0.24243289232254028, 'sampling/importance_sampling_ratio/mean': 1.0240256786346436, 'sampling/importance_sampling_ratio/max': 1.5776448249816895, 'entropy': 1.3606750071048737, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.04}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:24 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1356.72it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.80it/s, est. speed input: 3143.41 toks/s, output: 9801.17 toks/s]\n",
      "  8%|▊         | 2/25 [00:04<00:48,  2.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.088, 'grad_norm': 3.96875, 'learning_rate': 4.8e-05, 'num_tokens': 14605.0, 'completions/mean_length': 223.71875, 'completions/min_length': 15.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.8125, 'completions/mean_terminated_length': 83.83333587646484, 'completions/min_terminated_length': 15.0, 'completions/max_terminated_length': 195.0, 'rewards/reward_from_env/mean': 103.9625015258789, 'rewards/reward_from_env/std': 39.819908142089844, 'reward': 103.9625015258789, 'reward_std': 14.376317977905273, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.0843813493847847, 'sampling/sampling_logp_difference/max': 1.4582233428955078, 'sampling/importance_sampling_ratio/min': 0.23264925181865692, 'sampling/importance_sampling_ratio/mean': 1.0169103145599365, 'sampling/importance_sampling_ratio/max': 1.5537441968917847, 'entropy': 1.0307188630104065, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.08}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:26 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1658.48it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.41it/s, est. speed input: 2062.37 toks/s, output: 11055.28 toks/s]\n",
      " 12%|█▏        | 3/25 [00:06<00:42,  1.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0031, 'grad_norm': 3.40625, 'learning_rate': 4.600000000000001e-05, 'num_tokens': 24273.0, 'completions/mean_length': 254.625, 'completions/min_length': 212.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.96875, 'completions/mean_terminated_length': 212.0, 'completions/min_terminated_length': 212.0, 'completions/max_terminated_length': 212.0, 'rewards/reward_from_env/mean': 121.828125, 'rewards/reward_from_env/std': 23.6931209564209, 'reward': 121.828125, 'reward_std': 15.864302635192871, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.07195073366165161, 'sampling/sampling_logp_difference/max': 1.2133426666259766, 'sampling/importance_sampling_ratio/min': 0.2972021698951721, 'sampling/importance_sampling_ratio/mean': 1.011971116065979, 'sampling/importance_sampling_ratio/max': 1.4239530563354492, 'entropy': 0.8396566212177277, 'clip_ratio/low_mean': 0.0001220703125, 'clip_ratio/low_min': 0.0001220703125, 'clip_ratio/high_mean': 0.0009819228725973517, 'clip_ratio/high_max': 0.0009819228725973517, 'clip_ratio/region_mean': 0.0011039931850973517, 'epoch': 0.12}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:28 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1631.87it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.91it/s, est. speed input: 3905.33 toks/s, output: 10986.34 toks/s]\n",
      " 16%|█▌        | 4/25 [00:07<00:37,  1.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0021, 'grad_norm': 3.390625, 'learning_rate': 4.4000000000000006e-05, 'num_tokens': 35377.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 124.26875305175781, 'rewards/reward_from_env/std': 22.02469825744629, 'reward': 124.26875305175781, 'reward_std': 16.020652770996094, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.06552394479513168, 'sampling/sampling_logp_difference/max': 1.3148822784423828, 'sampling/importance_sampling_ratio/min': 0.26850593090057373, 'sampling/importance_sampling_ratio/mean': 1.0130901336669922, 'sampling/importance_sampling_ratio/max': 1.4909359216690063, 'entropy': 0.76953125, 'clip_ratio/low_mean': 0.0001220703125, 'clip_ratio/low_min': 0.0001220703125, 'clip_ratio/high_mean': 0.0006103515625, 'clip_ratio/high_max': 0.0006103515625, 'clip_ratio/region_mean': 0.000732421875, 'epoch': 0.16}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:29 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1804.78it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.95it/s, est. speed input: 1632.18 toks/s, output: 10995.63 toks/s]\n",
      " 20%|██        | 5/25 [00:09<00:34,  1.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0022, 'grad_norm': 3.625, 'learning_rate': 4.2e-05, 'num_tokens': 44785.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 124.9593734741211, 'rewards/reward_from_env/std': 20.39957046508789, 'reward': 124.9593734741211, 'reward_std': 16.528228759765625, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.07429219037294388, 'sampling/sampling_logp_difference/max': 1.242635726928711, 'sampling/importance_sampling_ratio/min': 0.2886224687099457, 'sampling/importance_sampling_ratio/mean': 1.0142146348953247, 'sampling/importance_sampling_ratio/max': 1.556097149848938, 'entropy': 0.8701171875, 'clip_ratio/low_mean': 0.0001220703125, 'clip_ratio/low_min': 0.0001220703125, 'clip_ratio/high_mean': 0.000732421875, 'clip_ratio/high_max': 0.000732421875, 'clip_ratio/region_mean': 0.0008544921875, 'epoch': 0.2}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:31 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1583.50it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 37.20it/s, est. speed input: 2278.77 toks/s, output: 9524.26 toks/s]\n",
      " 24%|██▍       | 6/25 [00:11<00:32,  1.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0007, 'grad_norm': 3.25, 'learning_rate': 4e-05, 'num_tokens': 54937.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 131.5187530517578, 'rewards/reward_from_env/std': 11.204102516174316, 'reward': 131.5187530517578, 'reward_std': 9.789403915405273, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.06402796506881714, 'sampling/sampling_logp_difference/max': 1.597865104675293, 'sampling/importance_sampling_ratio/min': 0.20232799649238586, 'sampling/importance_sampling_ratio/mean': 1.0114407539367676, 'sampling/importance_sampling_ratio/max': 1.5237730741500854, 'entropy': 0.70703125, 'clip_ratio/low_mean': 0.0003662109375, 'clip_ratio/low_min': 0.0003662109375, 'clip_ratio/high_mean': 0.0006103515625, 'clip_ratio/high_max': 0.0006103515625, 'clip_ratio/region_mean': 0.0009765625, 'epoch': 0.24}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:33 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1428.70it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.04it/s, est. speed input: 3443.20 toks/s, output: 11018.15 toks/s]\n",
      " 28%|██▊       | 7/25 [00:12<00:30,  1.70s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0017, 'grad_norm': 3.046875, 'learning_rate': 3.8e-05, 'num_tokens': 65689.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 138.65936279296875, 'rewards/reward_from_env/std': 7.89592981338501, 'reward': 138.65936279296875, 'reward_std': 5.593094825744629, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.06447682529687881, 'sampling/sampling_logp_difference/max': 1.380960464477539, 'sampling/importance_sampling_ratio/min': 0.2513370215892792, 'sampling/importance_sampling_ratio/mean': 1.0118381977081299, 'sampling/importance_sampling_ratio/max': 1.638052225112915, 'entropy': 0.70703125, 'clip_ratio/low_mean': 0.0003662109375, 'clip_ratio/low_min': 0.0003662109375, 'clip_ratio/high_mean': 0.000244140625, 'clip_ratio/high_max': 0.000244140625, 'clip_ratio/region_mean': 0.0006103515625, 'epoch': 0.28}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:34 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1692.28it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.56it/s, est. speed input: 3054.11 toks/s, output: 10896.80 toks/s]\n",
      " 32%|███▏      | 8/25 [00:14<00:28,  1.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0014, 'grad_norm': 3.078125, 'learning_rate': 3.6e-05, 'num_tokens': 76177.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 140.49374389648438, 'rewards/reward_from_env/std': 6.9356465339660645, 'reward': 140.49374389648438, 'reward_std': 6.394322395324707, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.04982556402683258, 'sampling/sampling_logp_difference/max': 1.3971309661865234, 'sampling/importance_sampling_ratio/min': 0.2473054677248001, 'sampling/importance_sampling_ratio/mean': 1.0109286308288574, 'sampling/importance_sampling_ratio/max': 1.5759567022323608, 'entropy': 0.5390625, 'clip_ratio/low_mean': 0.0001220703125, 'clip_ratio/low_min': 0.0001220703125, 'clip_ratio/high_mean': 0.000244140625, 'clip_ratio/high_max': 0.000244140625, 'clip_ratio/region_mean': 0.0003662109375, 'epoch': 0.32}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:36 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1776.87it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.24it/s, est. speed input: 1556.86 toks/s, output: 11070.94 toks/s]\n",
      " 36%|███▌      | 9/25 [00:16<00:27,  1.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0002, 'grad_norm': 2.78125, 'learning_rate': 3.4000000000000007e-05, 'num_tokens': 85521.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 146.29061889648438, 'rewards/reward_from_env/std': 11.264800071716309, 'reward': 146.29061889648438, 'reward_std': 8.717338562011719, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.05180970951914787, 'sampling/sampling_logp_difference/max': 1.4854364395141602, 'sampling/importance_sampling_ratio/min': 0.22640350461006165, 'sampling/importance_sampling_ratio/mean': 1.011709451675415, 'sampling/importance_sampling_ratio/max': 1.5055581331253052, 'entropy': 0.56396484375, 'clip_ratio/low_mean': 0.0001220703125, 'clip_ratio/low_min': 0.0001220703125, 'clip_ratio/high_mean': 0.000244140625, 'clip_ratio/high_max': 0.000244140625, 'clip_ratio/region_mean': 0.0003662109375, 'epoch': 0.36}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:38 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1612.42it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.11it/s, est. speed input: 2358.33 toks/s, output: 10676.88 toks/s]\n",
      " 40%|████      | 10/25 [00:17<00:25,  1.71s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0219, 'grad_norm': 2.859375, 'learning_rate': 3.2000000000000005e-05, 'num_tokens': 95426.0, 'completions/mean_length': 253.53125, 'completions/min_length': 177.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.96875, 'completions/mean_terminated_length': 177.0, 'completions/min_terminated_length': 177.0, 'completions/max_terminated_length': 177.0, 'rewards/reward_from_env/mean': 151.10000610351562, 'rewards/reward_from_env/std': 14.319016456604004, 'reward': 151.10000610351562, 'reward_std': 11.974573135375977, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.06071959435939789, 'sampling/sampling_logp_difference/max': 1.414407730102539, 'sampling/importance_sampling_ratio/min': 0.24306952953338623, 'sampling/importance_sampling_ratio/mean': 1.013283371925354, 'sampling/importance_sampling_ratio/max': 2.0, 'entropy': 0.6508394777774811, 'clip_ratio/low_mean': 0.00024903831945266575, 'clip_ratio/low_min': 0.00024903831945266575, 'clip_ratio/high_mean': 0.00037110863195266575, 'clip_ratio/high_max': 0.00037110863195266575, 'clip_ratio/region_mean': 0.0006201469514053315, 'epoch': 0.4}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:39 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1577.25it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.01it/s, est. speed input: 2677.89 toks/s, output: 10048.74 toks/s]\n",
      " 44%|████▍     | 11/25 [00:19<00:23,  1.69s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.1194, 'grad_norm': 4.46875, 'learning_rate': 3e-05, 'num_tokens': 104893.0, 'completions/mean_length': 233.59375, 'completions/min_length': 7.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.90625, 'completions/mean_terminated_length': 17.0, 'completions/min_terminated_length': 7.0, 'completions/max_terminated_length': 37.0, 'rewards/reward_from_env/mean': 139.95623779296875, 'rewards/reward_from_env/std': 44.71029281616211, 'reward': 139.95623779296875, 'reward_std': 26.40770149230957, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.07985985279083252, 'sampling/sampling_logp_difference/max': 1.7661480903625488, 'sampling/importance_sampling_ratio/min': 0.17099036276340485, 'sampling/importance_sampling_ratio/mean': 1.0169445276260376, 'sampling/importance_sampling_ratio/max': 1.905361294746399, 'entropy': 0.8642259538173676, 'clip_ratio/low_mean': 0.00013896609016228467, 'clip_ratio/low_min': 0.00013896609016228467, 'clip_ratio/high_mean': 0.0015211951395031065, 'clip_ratio/high_max': 0.0015211951395031065, 'clip_ratio/region_mean': 0.0016601612442173064, 'epoch': 0.44}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:41 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1282.76it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.06it/s, est. speed input: 3722.67 toks/s, output: 10768.32 toks/s]\n",
      " 48%|████▊     | 12/25 [00:21<00:21,  1.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0009, 'grad_norm': 4.28125, 'learning_rate': 2.8000000000000003e-05, 'num_tokens': 115917.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 158.69375610351562, 'rewards/reward_from_env/std': 14.300506591796875, 'reward': 158.69375610351562, 'reward_std': 11.313620567321777, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.06296081840991974, 'sampling/sampling_logp_difference/max': 2.0188417434692383, 'sampling/importance_sampling_ratio/min': 0.13280920684337616, 'sampling/importance_sampling_ratio/mean': 1.016083002090454, 'sampling/importance_sampling_ratio/max': 1.487716794013977, 'entropy': 0.671875, 'clip_ratio/low_mean': 0.0001220703125, 'clip_ratio/low_min': 0.0001220703125, 'clip_ratio/high_mean': 0.000732421875, 'clip_ratio/high_max': 0.000732421875, 'clip_ratio/region_mean': 0.0008544921875, 'epoch': 0.48}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:43 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1635.05it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 40.84it/s, est. speed input: 2021.76 toks/s, output: 10455.90 toks/s]\n",
      " 52%|█████▏    | 13/25 [00:22<00:20,  1.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0001, 'grad_norm': 4.71875, 'learning_rate': 2.6000000000000002e-05, 'num_tokens': 125693.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 163.359375, 'rewards/reward_from_env/std': 13.915802955627441, 'reward': 163.359375, 'reward_std': 12.928295135498047, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.07757525146007538, 'sampling/sampling_logp_difference/max': 2.3781018257141113, 'sampling/importance_sampling_ratio/min': 0.09272641688585281, 'sampling/importance_sampling_ratio/mean': 1.0198965072631836, 'sampling/importance_sampling_ratio/max': 1.5324656963348389, 'entropy': 0.89453125, 'clip_ratio/low_mean': 0.0008544921875, 'clip_ratio/low_min': 0.0008544921875, 'clip_ratio/high_mean': 0.000732421875, 'clip_ratio/high_max': 0.000732421875, 'clip_ratio/region_mean': 0.0015869140625, 'epoch': 0.52}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:44 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1258.51it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.02it/s, est. speed input: 3487.91 toks/s, output: 10504.36 toks/s]\n",
      " 56%|█████▌    | 14/25 [00:24<00:18,  1.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.059, 'grad_norm': 4.25, 'learning_rate': 2.4e-05, 'num_tokens': 136348.0, 'completions/mean_length': 249.96875, 'completions/min_length': 63.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.96875, 'completions/mean_terminated_length': 63.0, 'completions/min_terminated_length': 63.0, 'completions/max_terminated_length': 63.0, 'rewards/reward_from_env/mean': 167.2687530517578, 'rewards/reward_from_env/std': 26.248825073242188, 'reward': 167.2687530517578, 'reward_std': 19.7154598236084, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.06686343997716904, 'sampling/sampling_logp_difference/max': 1.2733135223388672, 'sampling/importance_sampling_ratio/min': 0.2799026370048523, 'sampling/importance_sampling_ratio/mean': 1.0195963382720947, 'sampling/importance_sampling_ratio/max': 1.6750993728637695, 'entropy': 0.754233181476593, 'clip_ratio/low_mean': 0.0003662109375, 'clip_ratio/low_min': 0.0003662109375, 'clip_ratio/high_mean': 0.0006103515625, 'clip_ratio/high_max': 0.0006103515625, 'clip_ratio/region_mean': 0.0009765625, 'epoch': 0.56}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:46 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1904.34it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.01it/s, est. speed input: 1387.31 toks/s, output: 11012.33 toks/s]\n",
      " 60%|██████    | 15/25 [00:25<00:16,  1.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0007, 'grad_norm': 3.015625, 'learning_rate': 2.2000000000000003e-05, 'num_tokens': 145572.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 173.2843780517578, 'rewards/reward_from_env/std': 9.06617546081543, 'reward': 173.2843780517578, 'reward_std': 8.840656280517578, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.05370745807886124, 'sampling/sampling_logp_difference/max': 1.5851564407348633, 'sampling/importance_sampling_ratio/min': 0.20491573214530945, 'sampling/importance_sampling_ratio/mean': 1.0140749216079712, 'sampling/importance_sampling_ratio/max': 1.6439921855926514, 'entropy': 0.6044921875, 'clip_ratio/low_mean': 0.0001220703125, 'clip_ratio/low_min': 0.0001220703125, 'clip_ratio/high_mean': 0.0003662109375, 'clip_ratio/high_max': 0.0003662109375, 'clip_ratio/region_mean': 0.00048828125, 'epoch': 0.6}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:47 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1568.40it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.42it/s, est. speed input: 2683.29 toks/s, output: 10860.34 toks/s]\n",
      " 64%|██████▍   | 16/25 [00:27<00:14,  1.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0022, 'grad_norm': 3.6875, 'learning_rate': 2e-05, 'num_tokens': 155788.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 175.69375610351562, 'rewards/reward_from_env/std': 24.015649795532227, 'reward': 175.69375610351562, 'reward_std': 16.03937530517578, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.07281577587127686, 'sampling/sampling_logp_difference/max': 1.3870906829833984, 'sampling/importance_sampling_ratio/min': 0.24980100989341736, 'sampling/importance_sampling_ratio/mean': 1.0174846649169922, 'sampling/importance_sampling_ratio/max': 1.7054287195205688, 'entropy': 0.85546875, 'clip_ratio/low_mean': 0.000244140625, 'clip_ratio/low_min': 0.000244140625, 'clip_ratio/high_mean': 0.001220703125, 'clip_ratio/high_max': 0.001220703125, 'clip_ratio/region_mean': 0.00146484375, 'epoch': 0.64}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:49 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1722.51it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.04it/s, est. speed input: 2001.76 toks/s, output: 11020.36 toks/s]\n",
      " 68%|██████▊   | 17/25 [00:29<00:12,  1.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0012, 'grad_norm': 4.25, 'learning_rate': 1.8e-05, 'num_tokens': 165468.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 180.13125610351562, 'rewards/reward_from_env/std': 18.830968856811523, 'reward': 180.13125610351562, 'reward_std': 17.099414825439453, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.07025197893381119, 'sampling/sampling_logp_difference/max': 1.3835391998291016, 'sampling/importance_sampling_ratio/min': 0.2506897449493408, 'sampling/importance_sampling_ratio/mean': 1.0149098634719849, 'sampling/importance_sampling_ratio/max': 2.0, 'entropy': 0.7724609375, 'clip_ratio/low_mean': 0.0003662109375, 'clip_ratio/low_min': 0.0003662109375, 'clip_ratio/high_mean': 0.0003662109375, 'clip_ratio/high_max': 0.0003662109375, 'clip_ratio/region_mean': 0.000732421875, 'epoch': 0.68}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:51 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1262.58it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 41.95it/s, est. speed input: 3671.35 toks/s, output: 10741.24 toks/s]\n",
      " 72%|███████▏  | 18/25 [00:30<00:11,  1.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0013, 'grad_norm': 3.4375, 'learning_rate': 1.6000000000000003e-05, 'num_tokens': 176460.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 187.6062469482422, 'rewards/reward_from_env/std': 17.654661178588867, 'reward': 187.6062469482422, 'reward_std': 15.463003158569336, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.06519210338592529, 'sampling/sampling_logp_difference/max': 1.522028923034668, 'sampling/importance_sampling_ratio/min': 0.21826860308647156, 'sampling/importance_sampling_ratio/mean': 1.0162379741668701, 'sampling/importance_sampling_ratio/max': 1.509002447128296, 'entropy': 0.748046875, 'clip_ratio/low_mean': 0.000244140625, 'clip_ratio/low_min': 0.000244140625, 'clip_ratio/high_mean': 0.0006103515625, 'clip_ratio/high_max': 0.0006103515625, 'clip_ratio/region_mean': 0.0008544921875, 'epoch': 0.72}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:52 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1816.50it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.03it/s, est. speed input: 1484.77 toks/s, output: 11017.33 toks/s]\n",
      " 76%|███████▌  | 19/25 [00:32<00:09,  1.59s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0013, 'grad_norm': 2.859375, 'learning_rate': 1.4000000000000001e-05, 'num_tokens': 185756.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 193.078125, 'rewards/reward_from_env/std': 13.261190414428711, 'reward': 193.078125, 'reward_std': 10.372503280639648, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.05257996916770935, 'sampling/sampling_logp_difference/max': 1.1692800521850586, 'sampling/importance_sampling_ratio/min': 0.3105904757976532, 'sampling/importance_sampling_ratio/mean': 1.0121604204177856, 'sampling/importance_sampling_ratio/max': 1.7386900186538696, 'entropy': 0.6025390625, 'clip_ratio/low_mean': 0.000244140625, 'clip_ratio/low_min': 0.000244140625, 'clip_ratio/high_mean': 0.00048828125, 'clip_ratio/high_max': 0.00048828125, 'clip_ratio/region_mean': 0.000732421875, 'epoch': 0.76}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:54 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 2023.55it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.07it/s, est. speed input: 1852.04 toks/s, output: 11026.02 toks/s]\n",
      " 80%|████████  | 20/25 [00:33<00:08,  1.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0006, 'grad_norm': 2.984375, 'learning_rate': 1.2e-05, 'num_tokens': 195324.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 200.0968780517578, 'rewards/reward_from_env/std': 11.700551986694336, 'reward': 200.0968780517578, 'reward_std': 10.329671859741211, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.056813858449459076, 'sampling/sampling_logp_difference/max': 1.5537490844726562, 'sampling/importance_sampling_ratio/min': 0.21145372092723846, 'sampling/importance_sampling_ratio/mean': 1.0133092403411865, 'sampling/importance_sampling_ratio/max': 1.617047667503357, 'entropy': 0.6494140625, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0001220703125, 'clip_ratio/high_max': 0.0001220703125, 'clip_ratio/region_mean': 0.0001220703125, 'epoch': 0.8}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:55 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1563.00it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.59it/s, est. speed input: 3343.48 toks/s, output: 10903.49 toks/s]\n",
      " 84%|████████▍ | 21/25 [00:35<00:06,  1.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0011, 'grad_norm': 2.96875, 'learning_rate': 1e-05, 'num_tokens': 206028.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 199.984375, 'rewards/reward_from_env/std': 8.327726364135742, 'reward': 199.984375, 'reward_std': 7.488068103790283, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.06027976796030998, 'sampling/sampling_logp_difference/max': 1.3746700286865234, 'sampling/importance_sampling_ratio/min': 0.25292304158210754, 'sampling/importance_sampling_ratio/mean': 1.0146417617797852, 'sampling/importance_sampling_ratio/max': 1.6927820444107056, 'entropy': 0.68359375, 'clip_ratio/low_mean': 0.0001220703125, 'clip_ratio/low_min': 0.0001220703125, 'clip_ratio/high_mean': 0.00048828125, 'clip_ratio/high_max': 0.00048828125, 'clip_ratio/region_mean': 0.0006103515625, 'epoch': 0.84}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:57 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1648.22it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.39it/s, est. speed input: 1897.13 toks/s, output: 10852.79 toks/s]\n",
      " 88%|████████▊ | 22/25 [00:37<00:04,  1.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0009, 'grad_norm': 2.90625, 'learning_rate': 8.000000000000001e-06, 'num_tokens': 215652.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 204.25625610351562, 'rewards/reward_from_env/std': 14.855082511901855, 'reward': 204.25625610351562, 'reward_std': 14.283214569091797, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.05508602410554886, 'sampling/sampling_logp_difference/max': 1.4210529327392578, 'sampling/importance_sampling_ratio/min': 0.2414596527814865, 'sampling/importance_sampling_ratio/mean': 1.0129733085632324, 'sampling/importance_sampling_ratio/max': 1.617936611175537, 'entropy': 0.6337890625, 'clip_ratio/low_mean': 0.0003662109375, 'clip_ratio/low_min': 0.0003662109375, 'clip_ratio/high_mean': 0.000732421875, 'clip_ratio/high_max': 0.000732421875, 'clip_ratio/region_mean': 0.0010986328125, 'epoch': 0.88}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:32:59 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1764.72it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.75it/s, est. speed input: 2244.40 toks/s, output: 10944.05 toks/s]\n",
      " 92%|█████████▏| 23/25 [00:38<00:03,  1.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0024, 'grad_norm': 3.3125, 'learning_rate': 6e-06, 'num_tokens': 225524.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 194.9343719482422, 'rewards/reward_from_env/std': 9.946619033813477, 'reward': 194.9343719482422, 'reward_std': 9.08108139038086, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.048330195248126984, 'sampling/sampling_logp_difference/max': 1.457280158996582, 'sampling/importance_sampling_ratio/min': 0.23286877572536469, 'sampling/importance_sampling_ratio/mean': 1.012702226638794, 'sampling/importance_sampling_ratio/max': 1.5781943798065186, 'entropy': 0.56689453125, 'clip_ratio/low_mean': 0.0006103515625, 'clip_ratio/low_min': 0.0006103515625, 'clip_ratio/high_mean': 0.0003662109375, 'clip_ratio/high_max': 0.0003662109375, 'clip_ratio/region_mean': 0.0009765625, 'epoch': 0.92}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:33:01 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1262.58it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 43.00it/s, est. speed input: 2193.11 toks/s, output: 11008.46 toks/s]\n",
      " 96%|█████████▌| 24/25 [00:40<00:01,  1.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.0001, 'grad_norm': 3.0, 'learning_rate': 4.000000000000001e-06, 'num_tokens': 235348.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 203.49061584472656, 'rewards/reward_from_env/std': 12.524025917053223, 'reward': 203.49061584472656, 'reward_std': 11.884631156921387, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.05112988501787186, 'sampling/sampling_logp_difference/max': 1.372579574584961, 'sampling/importance_sampling_ratio/min': 0.2534523010253906, 'sampling/importance_sampling_ratio/mean': 1.0137460231781006, 'sampling/importance_sampling_ratio/max': 1.5087206363677979, 'entropy': 0.57177734375, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.00048828125, 'clip_ratio/high_max': 0.00048828125, 'clip_ratio/region_mean': 0.00048828125, 'epoch': 0.96}\n",
      "\u001b[1;36m(EngineCore_DP0 pid=3787957)\u001b[0;0m INFO 10-31 17:33:02 [block_pool.py:292] Successfully reset prefix cache\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Adding requests: 100%|██████████| 4/4 [00:00<00:00, 1398.22it/s]\n",
      "Processed prompts: 100%|██████████| 32/32 [00:00<00:00, 42.75it/s, est. speed input: 3110.32 toks/s, output: 10944.83 toks/s]\n",
      "100%|██████████| 25/25 [00:42<00:00,  1.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': -0.0005, 'grad_norm': 3.0, 'learning_rate': 2.0000000000000003e-06, 'num_tokens': 245868.0, 'completions/mean_length': 256.0, 'completions/min_length': 256.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_from_env/mean': 199.96875, 'rewards/reward_from_env/std': 10.715091705322266, 'reward': 199.96875, 'reward_std': 8.32847785949707, 'frac_reward_zero_std': 0.0, 'sampling/sampling_logp_difference/mean': 0.05209578573703766, 'sampling/sampling_logp_difference/max': 1.2726020812988281, 'sampling/importance_sampling_ratio/min': 0.28010183572769165, 'sampling/importance_sampling_ratio/mean': 1.0139122009277344, 'sampling/importance_sampling_ratio/max': 1.518819808959961, 'entropy': 0.5849609375, 'clip_ratio/low_mean': 0.0001220703125, 'clip_ratio/low_min': 0.0001220703125, 'clip_ratio/high_mean': 0.0001220703125, 'clip_ratio/high_max': 0.0001220703125, 'clip_ratio/region_mean': 0.000244140625, 'epoch': 1.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [00:45<00:00,  1.83s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train_runtime': 45.7829, 'train_samples_per_second': 2.184, 'train_steps_per_second': 0.546, 'train_loss': -0.023757427856326105, 'epoch': 1.0}\n",
      "[2025-10-31 17:33:07,390][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:561] Training is Complete.\n",
      "[2025-10-31 17:33:07,391][oumi][rank0][pid:3788270][MainThread][INFO]][device_utils.py:343] GPU Metrics After Training: GPU runtime info: NVidiaGpuRuntimeInfo(device_index=0, device_count=2, used_memory_mb=75603.0, temperature=34, fan_speed=None, fan_speeds=None, power_usage_watts=124.915, power_limit_watts=700.0, gpu_utilization=0, memory_utilization=0, performance_state=0, clock_speed_graphics=1980, clock_speed_sm=1980, clock_speed_memory=2619).\n",
      "[2025-10-31 17:33:07,391][oumi][rank0][pid:3788270][MainThread][INFO]][torch_utils.py:135] Peak GPU memory usage: 10.15 GB\n",
      "[2025-10-31 17:33:07,391][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:568] Saving final state...\n",
      "[2025-10-31 17:33:07,395][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:573] Saving final model...\n",
      "[2025-10-31 17:33:09,102][oumi][rank0][pid:3788270][MainThread][INFO]][hf_trainer.py:127] Model has been saved at openenv_tutorial/echo_grpo\n",
      "[2025-10-31 17:33:09,102][oumi][rank0][pid:3788270][MainThread][INFO]][train.py:222] \n",
      "\n",
      "» We're always looking for feedback. What's one thing we can improve? https://oumi.ai/feedback\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CompletedProcess(args=['/home/wizeng/miniconda3/envs/openenv/bin/python', 'openenv_tutorial/train.py'], returncode=0)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "env = {**os.environ, \"CUDA_VISIBLE_DEVICES\": \"1\"}\n",
    "# Run the trainer as a subprocess to reinitialize CUDA with only the second GPU visible.\n",
    "subprocess.run(\n",
    "    [sys.executable, str(Path(tutorial_dir) / \"train.py\")], env=env, check=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you enabled Weights & Biases (wandb) logging, you should see a reward graph similar to the one below. Even though the training run was short, the model quickly learned to maximize the reward.\n",
    "\n",
    "\n",
    "![Echo env reward graph](./assets/openenv_echo_reward.png)"
   ]
  },
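  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, you can load the saved checkpoint and generate a completion. This is a minimal sketch, assuming the training cell above ran to completion (the logs show the model was saved in Hugging Face format at `openenv_tutorial/echo_grpo`) and that `transformers` is installed; the prompt string is purely illustrative:\n",
    "\n",
    "```python\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Load the GRPO-trained model and its tokenizer from the output directory.\n",
    "model = AutoModelForCausalLM.from_pretrained(\"openenv_tutorial/echo_grpo\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"openenv_tutorial/echo_grpo\")\n",
    "\n",
    "# Generate a short completion for an illustrative prompt.\n",
    "inputs = tokenizer(\"Echo this message back: hello\", return_tensors=\"pt\")\n",
    "outputs = model.generate(**inputs, max_new_tokens=64)\n",
    "print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
    "```"
   ]
  },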
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🧭 What's Next?\n",
    "\n",
    "Congrats on finishing this notebook! Feel free to check out our other [notebooks](https://github.com/oumi-ai/oumi/tree/main/notebooks) in the [Oumi GitHub](https://github.com/oumi-ai/oumi), and give us a star! You can also join the Oumi community over on [Discord](https://discord.gg/oumi).\n",
    "\n",
    "📰 Want to keep up with news from Oumi? Subscribe to our [Substack](https://blog.oumi.ai/) and [YouTube](https://www.youtube.com/@Oumi_AI)!\n",
    "\n",
    "⚡ Interested in building custom AI in hours, not months? Apply to get [early access](https://oumi.ai/contact?utm_source=oumi_oss_tutorial_openenv) to the Oumi Platform, or [chat with us](https://oumi.ai/book?utm_source=oumi_oss_tutorial_openenv) to learn more!"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "openenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
