# Source code for redel.base_kani

from contextlib import contextmanager
from typing import TYPE_CHECKING, AsyncIterable
from weakref import WeakValueDictionary

from kani import ChatMessage, ChatRole, Kani
from kani.engines.base import BaseCompletion
from kani.engines.openai import OpenAIEngine
from kani.streaming import StreamManager

from . import events
from .state import KaniState, RunState
from .utils import create_kani_id

if TYPE_CHECKING:
    from .app import ReDel


class BaseKani(Kani):
    """
    Base class for all kani in the application, regardless of recursive delegation.

    Extends :class:`kani.Kani`. See the Kani documentation for more details on the internal chat state and LLM
    interface.
    """

    def __init__(
        self,
        *args,
        app: "ReDel",
        parent: "BaseKani | None" = None,
        id: str | None = None,
        name: str | None = None,
        dispatch_creation: bool = True,
        **kwargs,
    ):
        """
        :param app: The :class:`.ReDel` instance this kani is a part of.
        :param parent: The parent of this kani, or ``None`` if this is the root of a system.
        :param id: The internal ID of this kani. If not passed, generates a UUID.
        :param name: The human-readable name of this kani. If not passed, uses the ID.
        :param dispatch_creation: Whether to dispatch a :class:`.events.KaniSpawn` event automatically. If false,
            the caller is responsible for calling ``app.on_kani_creation()`` to dispatch the event.
        """
        super().__init__(*args, **kwargs)
        self.state = RunState.STOPPED
        # states saved by nested ``run_state`` contexts, restored in LIFO order
        self._old_state_stack = []

        # tree management
        if parent is not None:
            self.depth = parent.depth + 1
        else:
            self.depth = 0
        self.parent = parent
        self.children = WeakValueDictionary()

        # app management
        self.id = create_kani_id() if id is None else id
        self.name = self.id if name is None else name
        self.app = app
        if dispatch_creation:
            app.on_kani_creation(self)

    # ==== overrides ====
    async def get_model_completion(self, include_functions: bool = True, **kwargs) -> BaseCompletion:
        """Request a completion from the engine, rewriting "no functions" requests for OpenAI engines."""
        # if include_functions is False but we have functions and are using an OpenAIEngine, we should set
        # tool_choice="none" instead -- this prevents the API from exploding if we set parallel_tool_calls
        if self.functions and (not include_functions) and isinstance(self.engine, OpenAIEngine):
            include_functions = True
            kwargs["tool_choice"] = "none"
        return await super().get_model_completion(include_functions=include_functions, **kwargs)

    async def get_model_stream(self, include_functions: bool = True, **kwargs) -> AsyncIterable[str | BaseCompletion]:
        """Streaming counterpart of :meth:`get_model_completion`, with the same OpenAI rewrite applied."""
        # same as above for streaming
        if self.functions and (not include_functions) and isinstance(self.engine, OpenAIEngine):
            include_functions = True
            kwargs["tool_choice"] = "none"
        async for elem in super().get_model_stream(include_functions=include_functions, **kwargs):
            yield elem

    async def chat_round(self, *args, **kwargs):
        """Run the parent's ``chat_round`` while this kani is marked as ``RunState.RUNNING``."""
        with self.run_state(RunState.RUNNING):
            return await super().chat_round(*args, **kwargs)

    def chat_round_stream(self, *args, **kwargs) -> StreamManager:
        """
        Wrap the parent's ``chat_round_stream`` so that each streamed token is also dispatched to the app as a
        :class:`.events.StreamDelta`, and the kani is marked ``RunState.RUNNING`` while the stream is consumed.
        """
        stream = super().chat_round_stream(*args, **kwargs)

        # consume from the inner StreamManager and re-yield with bookkeeping
        async def _impl():
            with self.run_state(RunState.RUNNING):
                async for token in stream:
                    yield token
                    self.app.dispatch(events.StreamDelta(id=self.id, delta=token, role=stream.role))
                yield await stream.completion()

        return StreamManager(_impl(), role=stream.role)

    async def full_round(self, *args, **kwargs):
        """Yield every message of the parent's ``full_round`` while marked as ``RunState.RUNNING``."""
        with self.run_state(RunState.RUNNING):
            async for msg in super().full_round(*args, **kwargs):
                yield msg

    async def full_round_stream(self, *args, **kwargs) -> AsyncIterable[StreamManager]:
        """
        Yield a wrapped :class:`StreamManager` for each stream of the parent's ``full_round_stream``; each wrapper
        dispatches a :class:`.events.StreamDelta` per token. The kani is marked ``RunState.RUNNING`` for the
        duration of the round.
        """
        with self.run_state(RunState.RUNNING):
            async for stream in super().full_round_stream(*args, **kwargs):
                # consume from the inner StreamManager and re-yield with bookkeeping
                # NOTE: the stream is bound as a default argument because the generator body runs lazily; a plain
                # closure would read whatever ``stream`` refers to at consumption time, which may be a later round
                async def _impl(inner: StreamManager = stream):
                    async for token in inner:
                        yield token
                        self.app.dispatch(events.StreamDelta(id=self.id, delta=token, role=inner.role))
                    yield await inner.completion()

                yield StreamManager(_impl(), role=stream.role)

    async def add_to_history(self, message: ChatMessage):
        """
        Add the message to the chat history, then dispatch a :class:`.events.KaniMessage` (and additionally a
        :class:`.events.RootMessage` when this kani has no parent).
        """
        await super().add_to_history(message)
        self.app.dispatch(events.KaniMessage(id=self.id, msg=message))
        if self.parent is None:
            self.app.dispatch(events.RootMessage(msg=message))

    async def add_completion_to_history(self, completion):
        """
        Add the completion to the chat history, dispatch a :class:`.events.TokensUsed` event with its token counts,
        and strip a spurious ``functions.`` prefix from any tool call names.

        :returns: The message returned by the parent implementation (mutated in place if a name was fixed).
        """
        message = await super().add_completion_to_history(completion)
        self.app.dispatch(
            events.TokensUsed(
                id=self.id, prompt_tokens=completion.prompt_tokens, completion_tokens=completion.completion_tokens
            )
        )
        # HACK: sometimes openai's function calls are borked; we fix them here
        if message.tool_calls:
            for tc in message.tool_calls:
                if (function_call := tc.function) and function_call.name.startswith("functions."):
                    function_call.name = function_call.name.removeprefix("functions.")
        return message

    # ==== utils ====
    @property
    def last_user_message(self) -> ChatMessage | None:
        """The most recent USER message in this kani's chat history, if one exists."""
        return next((m for m in reversed(self.chat_history) if m.role == ChatRole.USER), None)

    @property
    def last_assistant_message(self) -> ChatMessage | None:
        """The most recent ASSISTANT message in this kani's chat history, if one exists."""
        return next((m for m in reversed(self.chat_history) if m.role == ChatRole.ASSISTANT), None)

    def get_save_state(self) -> KaniState:
        """Get a Pydantic state suitable for saving/loading."""
        return KaniState.from_kani(self)

    # --- state utils ---
    def set_run_state(self, state: RunState):
        """Set the run state and dispatch the event."""
        # noop if we're already in that state
        if self.state == state:
            return
        self.state = state
        self.app.dispatch(events.KaniStateChange(id=self.id, state=self.state))

    @contextmanager
    def run_state(self, state: RunState):
        """Run the body of this statement with a different run state, then set it back afterwards."""
        # remember the current state on a stack so nested contexts unwind in the right order
        self._old_state_stack.append(self.state)
        self.set_run_state(state)
        try:
            yield
        finally:
            self.set_run_state(self._old_state_stack.pop())

    async def cleanup(self):
        """This kani may run again but is done for now; clean up any ephemeral resources but save its state."""
        pass

    async def close(self):
        """The application is shutting down and all resources should be gracefully cleaned up."""
        pass