# Source code for oumi.environments.synthetic_environment
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Synthetic environment backed by LLM-simulated tool execution."""
from __future__ import annotations
import copy
import json
from dataclasses import dataclass
from typing import Any
import jsonschema
from oumi.core.configs.params.base_params import BaseParams
from oumi.core.configs.params.environment_params import EnvironmentParams
from oumi.core.configs.params.tool_params import ToolParams
from oumi.core.registry import register_environment
from oumi.core.types.tool_call import ToolResult
from oumi.environments.base_environment import BaseEnvironment
[docs]
@dataclass
class SyntheticStateParams(BaseParams):
    """State settings that a synthetic environment may optionally carry.

    Both fields default to ``None``. When both are supplied, ``initial_state``
    is checked against ``state_schema`` with ``jsonschema.validate`` as soon as
    the dataclass is constructed.
    """

    # Schema passed to ``jsonschema.validate``; ``None`` skips validation.
    state_schema: dict[str, Any] | None = None
    # State payload validated against ``state_schema`` when both are present.
    initial_state: dict[str, Any] | None = None

    def __post_init__(self):
        """Validate ``initial_state`` against ``state_schema`` when both are set.

        Nothing is checked if either field is ``None``. Any error raised by
        ``jsonschema.validate`` propagates to the caller unchanged.
        """
        if self.state_schema is None or self.initial_state is None:
            return
        jsonschema.validate(self.initial_state, self.state_schema)
[docs]
@dataclass
class SyntheticEnvironmentKwargs(BaseParams):
    """Keyword arguments that are specific to ``SyntheticEnvironment``."""

    # Must be non-empty by the time ``__finalize_and_validate__`` runs.
    system_prompt: str = ""
    # Optional state configuration; a plain dict is coerced in ``__post_init__``.
    state_params: SyntheticStateParams | None = None
    # Must be False whenever ``state_params`` is provided.
    cache_by_input: bool = True

    def __post_init__(self) -> None:
        """Promote a plain-dict ``state_params`` to ``SyntheticStateParams``."""
        raw_state = self.state_params
        if isinstance(raw_state, dict):
            self.state_params = SyntheticStateParams(**raw_state)

    def __finalize_and_validate__(self) -> None:
        """Check that the kwargs form a usable configuration.

        Raises:
            ValueError: If ``system_prompt`` is empty, or if ``state_params`` is
                provided while ``cache_by_input`` is still enabled.
        """
        prompt_missing = not self.system_prompt
        if prompt_missing:
            raise ValueError(
                "SyntheticEnvironmentKwargs.system_prompt cannot be empty."
            )

        has_state = self.state_params is not None
        if has_state and self.cache_by_input:
            message = (
                "SyntheticEnvironmentKwargs.cache_by_input must be False when "
                "state_params is provided."
            )
            raise ValueError(message)
[docs]
@register_environment("synthetic")
class SyntheticEnvironment(BaseEnvironment):
    """Environment for LLM-simulated tool calls, with an optional state dict.

    Registered under the name ``"synthetic"``. Holds an in-memory result cache
    keyed by tool id and arguments, plus a private copy of the configured
    initial state (if any). ``step()`` is not implemented yet.
    """

    def __init__(
        self,
        params: EnvironmentParams,
        kwargs: SyntheticEnvironmentKwargs,
    ) -> None:
        """Store the configuration and set up an empty cache and initial state.

        Args:
            params: Environment parameters; ``tools`` and ``id`` are read by
                other methods of this class.
            kwargs: Synthetic-environment-specific settings.
        """
        self._params = params
        self._kwargs = kwargs
        self._cache: dict[str, ToolResult] = {}

        # Deep-copy so later changes to the environment's state never leak back
        # into the configuration object.
        starting_state = None
        state_cfg = kwargs.state_params
        if state_cfg is not None and state_cfg.initial_state is not None:
            starting_state = copy.deepcopy(state_cfg.initial_state)
        self._state: dict[str, Any] | None = starting_state

    @classmethod
    def from_params(cls, params: EnvironmentParams) -> SyntheticEnvironment:
        """Create an environment from ``params``, validating its ``env_kwargs``.

        Args:
            params: Environment parameters; a missing or empty ``env_kwargs``
                is treated as an empty mapping.

        Returns:
            A new instance built from ``params`` and the validated kwargs.
        """
        raw_kwargs = params.env_kwargs or {}
        env_kwargs = SyntheticEnvironmentKwargs(**raw_kwargs)
        env_kwargs.finalize_and_validate()
        return cls(params, env_kwargs)

    @property
    def current_state(self) -> dict[str, Any] | None:
        """Deep copy of the in-memory state, or ``None`` if there is no state."""
        return None if self._state is None else copy.deepcopy(self._state)

    @staticmethod
    def _cache_key(tool_id: str, arguments: dict[str, Any]) -> str:
        """Return a cache key made of the tool id and key-sorted JSON arguments."""
        serialized_args = json.dumps(arguments, sort_keys=True)
        return f"{tool_id}::{serialized_args}"

    @staticmethod
    def _copy_result(result: ToolResult) -> ToolResult:
        """Return a new ``ToolResult`` holding deep copies of ``result``'s data."""
        return ToolResult(
            output=copy.deepcopy(result.output),
            updated_state=copy.deepcopy(result.updated_state),
        )

    def _resolve_cached(
        self, tool_id: str, arguments: dict[str, Any]
    ) -> ToolResult | None:
        """Return a copy of the cached result for this call, if one exists.

        Returns ``None`` when caching is disabled or nothing is cached for the
        given tool id and arguments.
        """
        if not self._kwargs.cache_by_input:
            return None
        key = self._cache_key(tool_id, arguments)
        cached = self._cache.get(key)
        if cached is None:
            return None
        # Hand out a copy so callers cannot mutate the cached entry.
        return self._copy_result(cached)

    def _cache_result(
        self, tool_id: str, arguments: dict[str, Any], result: ToolResult
    ) -> None:
        """Remember a copy of ``result`` for this call when caching is enabled."""
        if not self._kwargs.cache_by_input:
            return
        key = self._cache_key(tool_id, arguments)
        # Store a copy so later mutation of ``result`` cannot alter the cache.
        self._cache[key] = self._copy_result(result)

    def _lookup_tool(self, tool_id: str) -> ToolParams:
        """Return the configured tool whose ``id`` equals ``tool_id``.

        Raises:
            ValueError: If no tool in ``params.tools`` has that id.
        """
        match = next(
            (tool for tool in self._params.tools if tool.id == tool_id),
            None,
        )
        if match is not None:
            return match

        env_id = self._params.id
        available_ids = [tool.id for tool in self._params.tools]
        raise ValueError(
            f"Tool '{tool_id}' not found in environment '{env_id}'. "
            f"Available tools: {available_ids}"
        )

    def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult:
        """Run a synthetic tool call (currently unimplemented).

        Raises:
            ValueError: If ``tool_id`` does not name a configured tool.
            NotImplementedError: Always, once the tool id has been validated.
        """
        # Unknown tool ids are reported before the missing implementation.
        self._lookup_tool(tool_id)
        message = "SyntheticEnvironment.step() is not implemented yet."
        raise NotImplementedError(message)