# Source code for oumi.core.configs.params.environment_params
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for a single agentic environment."""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
from oumi.core.configs.params.base_params import BaseParams
from oumi.core.configs.params.grounding_params import GroundingConfig
from oumi.core.configs.params.tool_params import ToolParams
from oumi.utils.logging import logger
@dataclass
class EnvironmentParams(BaseParams):
    """Pure-data description of an environment.

    Raw tool entries and a raw ``grounding`` mapping are coerced into their
    typed forms in ``__post_init__``; cross-field checks (required fields,
    unique tool ids, registered ``env_type``, grounding consistency) live in
    ``__finalize_and_validate__``.
    """

    # Identifier of this environment; must be non-empty. Also used to label
    # validation errors and warnings.
    id: str = ""
    # Human-readable name; must be non-empty.
    name: str = ""
    # Free-form description; must be non-empty.
    description: str = ""
    # Key of an entry registered under ``RegistryType.ENVIRONMENT``; must be
    # non-empty and registered.
    env_type: str = ""
    # Tool definitions: already-typed tool params, or raw values accepted by
    # the tool params class's ``create``. Tool ids must be unique. ``None`` is
    # treated as an empty list.
    tools: list[Any] = field(default_factory=list)
    # Optional extra keyword arguments for the environment (presumably
    # forwarded to the environment class -- confirm); must be a mapping or None.
    env_kwargs: dict[str, Any] | None = None
    # Optional env-level grounding. A raw mapping is coerced into
    # ``GroundingConfig``; when set, ``grounding.tools`` must be non-empty.
    grounding: GroundingConfig | None = None

    def __post_init__(self) -> None:
        """Coerce raw tool entries and a raw grounding mapping into typed params.

        Tool entries that are not already instances of the environment's tool
        params class are converted via ``<tool_cls>.create``. That class comes
        from the registered environment's ``tool_params_cls`` and falls back to
        ``ToolParams`` when the lookup yields nothing.
        """
        tool_cls = self._resolve_tool_cls() or ToolParams
        # Treat an explicit ``tools: null`` (e.g. from a YAML config) as "no
        # tools" rather than failing with "'NoneType' object is not iterable".
        raw_tools = self.tools if self.tools is not None else []
        self.tools = [
            tool if isinstance(tool, tool_cls) else tool_cls.create(tool)
            for tool in raw_tools
        ]
        if self.grounding is not None and not isinstance(
            self.grounding, GroundingConfig
        ):
            # Anything that is not already a GroundingConfig is expanded as
            # keyword arguments for it.
            self.grounding = GroundingConfig(**self.grounding)

    def _resolve_tool_cls(self) -> type[ToolParams] | None:
        """Look up the registered env class and return its ``tool_params_cls``.

        Returns None when ``env_type`` is empty, is not registered, or the
        registered class has no ``tool_params_cls`` attribute. Unknown env
        types are reported later by ``__finalize_and_validate__``.

        The registry import is local to avoid a circular dependency between
        ``core/configs/params/`` and ``environments/``.
        """
        if not self.env_type:
            # Nothing to look up; an empty env_type is rejected at validation.
            return None
        from oumi.core.registry import REGISTRY, RegistryType

        env_cls = REGISTRY.get(self.env_type, RegistryType.ENVIRONMENT)
        return getattr(env_cls, "tool_params_cls", None) if env_cls else None

    def __finalize_and_validate__(self) -> None:
        """Validate common fields and registry membership.

        Raises:
            ValueError: If ``id``, ``name``, ``description`` or ``env_type`` is
                empty; if ``env_kwargs`` is neither None nor a mapping; if two
                tools share an id; if ``env_type`` is not a registered
                environment; or if grounding is set with empty
                ``grounding.tools``.
        """
        cls_name = type(self).__name__
        # Checked in declaration order so the first empty field is reported.
        for field_name in ("id", "name", "description", "env_type"):
            if not getattr(self, field_name):
                raise ValueError(f"{cls_name}.{field_name} cannot be empty.")
        if self.env_kwargs is not None and not isinstance(self.env_kwargs, Mapping):
            raise ValueError(
                f"{cls_name}.env_kwargs must be a mapping or None, "
                f"got {type(self.env_kwargs).__name__}."
            )
        self._validate_unique_tool_ids()
        self._validate_env_type_registered()
        self._validate_grounding_has_tools()
        self._warn_on_stale_grounding_tool_ids()

    def _validate_unique_tool_ids(self) -> None:
        """Raise ValueError on the first tool id that appears more than once."""
        seen: set[str] = set()
        for tool in self.tools:
            if tool.id in seen:
                raise ValueError(
                    f"{type(self).__name__} '{self.id}' contains duplicate "
                    f"tool id '{tool.id}'."
                )
            seen.add(tool.id)

    def _validate_env_type_registered(self) -> None:
        """Raise ValueError, listing the known types, if env_type is unregistered."""
        # Local import for the same circular-dependency reason as above.
        from oumi.core.registry import REGISTRY, RegistryType

        if not REGISTRY.contains(self.env_type, RegistryType.ENVIRONMENT):
            known = sorted(REGISTRY.get_all(RegistryType.ENVIRONMENT))
            raise ValueError(
                f"Unknown env_type '{self.env_type}'. Known types: {known}"
            )

    def _validate_grounding_has_tools(self) -> None:
        """If env-level grounding is set, ``grounding.tools`` must be non-empty."""
        if self.grounding is None:
            return
        if not self.grounding.tools:
            raise ValueError(
                f"{type(self).__name__} '{self.id}' declares grounding but "
                f"grounding.tools is empty. Add at least one tool entry, or "
                f"remove env-level grounding."
            )

    def _warn_on_stale_grounding_tool_ids(self) -> None:
        """Log a warning for ``grounding.tools`` entries naming unknown tools.

        This only warns; it does not remove the entries itself.
        """
        if self.grounding is None:
            return
        tool_ids = {tool.id for tool in self.tools}
        for tool_id in self.grounding.tools:
            if tool_id not in tool_ids:
                logger.warning(
                    "Environment '%s': grounding.tools.'%s' references unknown "
                    "tool. Entry will be ignored.",
                    self.id,
                    tool_id,
                )