# Source code for oumi.core.configs.params.environment_params

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration for a single agentic environment."""

from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any

from oumi.core.configs.params.base_params import BaseParams
from oumi.core.configs.params.grounding_params import GroundingConfig
from oumi.core.configs.params.tool_params import ToolParams
from oumi.utils.logging import logger


@dataclass
class EnvironmentParams(BaseParams):
    """Pure-data description of an environment.

    Construction (``__post_init__``) normalizes ``tools`` and ``grounding``
    into typed objects. ``__finalize_and_validate__`` then checks the
    required string fields, the ``env_kwargs`` type, tool-id uniqueness,
    ``env_type`` registration and grounding consistency.
    """

    # Identifier of this environment; must be non-empty at validation time.
    id: str = ""
    # Human-readable name; must be non-empty at validation time.
    name: str = ""
    # Free-text description; must be non-empty at validation time.
    description: str = ""
    # Key looked up in the ENVIRONMENT registry; must be non-empty and
    # registered at validation time.
    env_type: str = ""
    # Tool definitions. Entries that are not already instances of the
    # resolved tool-params class are converted via that class's ``create``
    # in ``__post_init__``.
    tools: list[Any] = field(default_factory=list)
    # Keyword arguments for the environment (how they are consumed is not
    # visible here); must be a Mapping or None at validation time.
    env_kwargs: dict[str, Any] | None = None
    # Optional env-level grounding. A value that is not a GroundingConfig is
    # expanded as ``GroundingConfig(**value)`` in ``__post_init__``.
    grounding: GroundingConfig | None = None

    def __post_init__(self) -> None:
        """Coerce ``tools`` entries and ``grounding`` into their typed forms.

        The tool class is the one advertised by the registered environment
        class (see ``_resolve_tool_cls``), falling back to ``ToolParams``.
        Tools already of that class are kept as-is; anything else is passed
        to the class's ``create``. A non-``GroundingConfig`` grounding value
        is unpacked as keyword arguments into ``GroundingConfig``.
        """
        resolved_cls = self._resolve_tool_cls()
        tool_cls = resolved_cls if resolved_cls else ToolParams

        # Build into a local list so ``self.tools`` is only replaced once
        # every entry has been converted successfully.
        normalized: list[Any] = []
        for entry in self.tools:
            if isinstance(entry, tool_cls):
                normalized.append(entry)
            else:
                normalized.append(tool_cls.create(entry))
        self.tools = normalized

        grounding = self.grounding
        if grounding is None or isinstance(grounding, GroundingConfig):
            return
        self.grounding = GroundingConfig(**grounding)

    def _resolve_tool_cls(self) -> type[ToolParams] | None:
        """Return the registered env class's ``tool_params_cls``, if any.

        Yields None when ``env_type`` does not resolve to a registered
        environment class, or when that class has no ``tool_params_cls``
        attribute. Unregistered types are reported later by validation.

        The registry import is deferred to call time to avoid a circular
        dependency between ``core/configs/params/`` and ``environments/``.
        """
        from oumi.core.registry import REGISTRY, RegistryType

        env_cls = REGISTRY.get(self.env_type, RegistryType.ENVIRONMENT)
        if not env_cls:
            return None
        return getattr(env_cls, "tool_params_cls", None)

    def __finalize_and_validate__(self) -> None:
        """Validate common fields, then run the registry/tool/grounding checks.

        Raises:
            ValueError: If ``id``, ``name``, ``description`` or ``env_type``
                is empty (checked in that order); if ``env_kwargs`` is
                neither None nor a Mapping; or if a helper check fails
                (duplicate tool ids, unregistered ``env_type``, or grounding
                declared with an empty ``grounding.tools``).
        """
        cls_name = type(self).__name__
        for attr in ("id", "name", "description", "env_type"):
            if not getattr(self, attr):
                raise ValueError(f"{cls_name}.{attr} cannot be empty.")

        kwargs = self.env_kwargs
        if kwargs is not None and not isinstance(kwargs, Mapping):
            raise ValueError(
                f"{cls_name}.env_kwargs must be a mapping or None, "
                f"got {type(kwargs).__name__}."
            )

        self._validate_unique_tool_ids()
        self._validate_env_type_registered()
        self._validate_grounding_has_tools()
        self._warn_on_stale_grounding_tool_ids()

    def _validate_unique_tool_ids(self) -> None:
        """Raise ``ValueError`` at the first tool whose id was already seen."""
        encountered: set[str] = set()
        for candidate in self.tools:
            if candidate.id in encountered:
                raise ValueError(
                    f"{type(self).__name__} '{self.id}' contains duplicate "
                    f"tool id '{candidate.id}'."
                )
            encountered.add(candidate.id)

    def _validate_env_type_registered(self) -> None:
        """Raise ``ValueError`` listing known types if ``env_type`` is unknown."""
        from oumi.core.registry import REGISTRY, RegistryType

        if REGISTRY.contains(self.env_type, RegistryType.ENVIRONMENT):
            return
        known = sorted(REGISTRY.get_all(RegistryType.ENVIRONMENT))
        raise ValueError(
            f"Unknown env_type '{self.env_type}'. Known types: {known}"
        )

    def _validate_grounding_has_tools(self) -> None:
        """If env-level grounding is set, ``grounding.tools`` must be non-empty."""
        if self.grounding is not None and not self.grounding.tools:
            raise ValueError(
                f"{type(self).__name__} '{self.id}' declares grounding but "
                f"grounding.tools is empty. Add at least one tool entry, or "
                f"remove env-level grounding."
            )

    def _warn_on_stale_grounding_tool_ids(self) -> None:
        """Log a warning for ``grounding.tools`` entries naming unknown tools.

        This is advisory only: nothing is raised or removed here.
        """
        if self.grounding is None:
            return
        declared = {tool.id for tool in self.tools}
        stale = [tid for tid in self.grounding.tools if tid not in declared]
        for tid in stale:
            logger.warning(
                "Environment '%s': grounding.tools.'%s' references unknown "
                "tool. Entry will be ignored.",
                self.id,
                tid,
            )