import asyncio
import datetime
import inspect
import logging
from kani import AIFunction, ChatMessage
from . import events
from .base_kani import BaseKani
from .delegation import DelegationBase
from .namer import Namer
from .tool_config import ToolConfigType
from .tools import ToolBase
log = logging.getLogger(__name__)

# ==== prompts ====
# System prompt templates. ``{time}`` (and, for delegates, ``{name}``) are filled in by get_system_prompt().

# Template for the root kani; only uses the ``{time}`` placeholder.
DEFAULT_ROOT_PROMPT = (
    "# Goals\n\n"
    "Your goal is to answer the user's questions and help them out by performing actions. "
    "While you may be able to answer many questions from memory alone, "
    "the user's queries will sometimes require you to take actions. "
    "You can use the provided function to ask your capable helpers, who can help you take actions.\n"
    "The current time is {time}."
)

# Template for delegate (child) kanis; uses the ``{name}`` and ``{time}`` placeholders.
DEFAULT_DELEGATE_PROMPT = (
    "You are {name}, a helpful assistant with the goal of answering the user's questions as precisely as possible "
    "and helping them out by performing actions.\n"
    "You can use the provided functions to take actions yourself or break queries up into smaller pieces "
    "and ask your capable helpers, who can help you.\n"
    "If the user's query involves multiple steps, you should break it up into smaller pieces and delegate those "
    "pieces - for example, if you need to look up multiple sites, delegate each search to a helper. "
    "Say your plan before you do. "
    "If those pieces can be resolved at the same time, delegate them all at once. "
    "You may do multiple rounds of delegating for additional steps that depend on earlier steps.\n"
    "The current time is {time}."
)
def get_system_prompt(kani: "BaseKani") -> str:
    """Fill in the system prompt template from the kani.

    The template may reference ``{name}`` (the kani's name) and ``{time}`` (the current local time, rendered with the
    strftime pattern ``"%a %d %b %Y, %I:%M%p"``). Valid templates are rendered with :meth:`str.format`, so ``{{``/``}}``
    escapes keep working.

    If the template is not a valid format string - e.g. a customised prompt that contains a JSON example, an unknown
    ``{placeholder}``, or a stray brace - only the literal ``{name}`` and ``{time}`` markers are substituted and all
    other text is left untouched. Previously such prompts raised on every prompt build.

    :param kani: Object exposing ``system_prompt`` (the template string) and ``name``.
    :returns: The rendered system prompt.
    """
    now = datetime.datetime.now().strftime("%a %d %b %Y, %I:%M%p")
    fields = {"name": kani.name, "time": now}
    template = kani.system_prompt
    try:
        return template.format(**fields)
    except (KeyError, IndexError, ValueError, AttributeError, TypeError):
        # These are the errors str.format raises for unknown fields, positional fields, malformed braces, or bad
        # attribute/index/format-spec access.
        # Fall back to a literal substitution of just the known placeholders.
        rendered = template
        for key, value in fields.items():
            rendered = rendered.replace("{" + key + "}", str(value))
        return rendered
# ==== implementation ====
class ReDelKani(BaseKani):
    """Base class for recursive delegation kanis. Extends :class:`.BaseKani`.

    This class should not be constructed manually - it is tightly coupled to and managed by the application. You can
    get a reference to a kani powering an agent in a tool by using :attr:`.ToolBase.kani`.
    """

    def __init__(self, *args, **kwargs):
        # default retry_attempts to 10 unless the caller passed their own value
        kwargs.setdefault("retry_attempts", 10)
        super().__init__(*args, **kwargs)
        # source of names for child kanis created by create_delegate_kani()
        self.namer = Namer()
        # the delegation scheme instance and tool instances; assigned by _register_tools()
        self.delegator = None
        self.tools = []

    def _register_tools(self, delegator: DelegationBase | None, tools: list[ToolBase]):
        """Overwrite this kani's functions with the functions provided by the given delegation scheme and tools.

        Should be called only once, immediately after __init__.

        If two providers expose an AIFunction with the same name, the later one wins (delegator first, then tools in
        order) and a warning is logged.

        :param delegator: The delegation scheme instance for this kani, or None if it should not delegate.
        :param tools: The tool instances available to this kani.
        """
        new_functions: dict[str, AIFunction] = {}

        def collect(inst):
            # get_tool_functions() rejects duplicate names *within* one provider, but collisions *across* providers
            # used to overwrite silently; keep last-wins for compatibility but make it visible.
            for fn_name, fn in get_tool_functions(inst).items():
                if fn_name in new_functions:
                    log.warning(
                        "AIFunction %r provided by %s overrides an earlier function with the same name.",
                        fn_name,
                        type(inst).__name__,
                    )
                new_functions[fn_name] = fn

        # find all registered ai_functions in the delegation scheme and tools and save them
        self.delegator = delegator
        if delegator:
            collect(delegator)
        self.tools = tools
        for inst in tools:
            collect(inst)
        self.functions = new_functions

    async def create_delegate_kani(self, instructions: str):
        """Create a child kani, set it up via :meth:`register_child_kani`, and dispatch a ``KaniDelegated`` event.

        :param instructions: The instructions this kani is delegating to the child.
        :returns: The newly created child :class:`ReDelKani`.
        """
        # create the new instance
        name = self.namer.get_name()
        kani_inst = ReDelKani(
            self.app.delegate_engine,
            # app args
            app=self.app,
            parent=self,
            name=name,
            dispatch_creation=False,
            # kani args
            system_prompt=self.app.delegate_system_prompt,
            **self.app.delegate_kani_kwargs,
        )
        await self.register_child_kani(kani_inst, instructions)
        # parent_message_idx: index of this kani's most recent message.
        # child_message_idx: index the child's next message will occupy.
        self.app.dispatch(
            events.KaniDelegated(
                parent_id=self.id,
                child_id=kani_inst.id,
                parent_message_idx=len(self.chat_history) - 1,
                child_message_idx=len(kani_inst.chat_history),
                instructions=instructions,
            )
        )
        return kani_inst

    async def register_child_kani(self, kani_inst, instructions: str | None):
        """Give a child kani its delegation scheme and tools, run their setup, then report it to the app.

        :param kani_inst: The child kani to set up.
        :param instructions: The instructions the child was created for. Currently unused (see the TODO below).
        """
        # set up tools
        # delegation: the child gets no delegation scheme if none is configured, or if this (parent) kani's depth
        # equals the app's max_delegation_depth
        if self.app.delegation_scheme is None or self.depth == self.app.max_delegation_depth:
            delegation_scheme_inst = None
        else:
            delegation_scheme_inst = self.app.delegation_scheme(app=self.app, kani=kani_inst)
        # tools, TODO with the retrieved functions to use if instructions are given
        tool_insts = []
        for t, config in self.app.tool_configs.items():
            if config.get("always_include", False):
                tool_insts.append(t(app=self.app, kani=kani_inst, **config.get("kwargs", {})))
        # noinspection PyProtectedMember
        kani_inst._register_tools(delegator=delegation_scheme_inst, tools=tool_insts)
        if delegation_scheme_inst:
            await delegation_scheme_inst.setup()
        await asyncio.gather(*(t.setup() for t in tool_insts))
        # bookkeeping: the child was built with dispatch_creation=False; the app is notified only after setup
        self.app.on_kani_creation(kani_inst)

    # overrides
    async def get_prompt(self) -> list[ChatMessage]:
        """Build the prompt, re-rendering the system prompt template (``{name}``/``{time}``) on every call."""
        # if we have a system prompt, update it with any time/name templates
        # NOTE(review): assumes the system message is always_included_messages[0] whenever system_prompt is set
        if self.system_prompt is not None:
            self.always_included_messages[0] = ChatMessage.system(get_system_prompt(self))
        return await super().get_prompt()

    async def cleanup(self):
        """Clean up the delegation scheme, then all tools, then the base kani.

        The base kani's cleanup always runs, even if the delegator or a tool raises. The exception still propagates.
        """
        try:
            if self.delegator:
                await self.delegator.cleanup()
        finally:
            try:
                await asyncio.gather(*(t.cleanup() for t in self.tools))
            finally:
                await super().cleanup()

    async def close(self):
        """Close the delegation scheme, then all tools, then the base kani.

        Each later stage still runs if an earlier one raises, so the base kani's resources are always released. The
        first exception still propagates.
        """
        try:
            if self.delegator:
                await self.delegator.close()
        finally:
            try:
                await asyncio.gather(*(t.close() for t in self.tools))
            finally:
                await super().close()
async def create_root_kani(
    *args,
    app,
    delegation_scheme: type[DelegationBase] | None,
    tool_configs: ToolConfigType,
    root_has_tools: bool,
    **kwargs,
) -> ReDelKani:
    """Create the root kani for the kani delegation tree.

    Builds the kani and attaches the delegation scheme (if any) and the tools selected for the root. It then runs
    their async setup and only after that notifies the app via ``app.on_kani_creation``.

    :param args: Positional arguments forwarded to :class:`ReDelKani`.
    :param app: The owning application.
    :param delegation_scheme: Class of the delegation scheme to instantiate for the root, or None for no delegation.
    :param tool_configs: Mapping of tool class to its config dict. Keys read here: ``always_include_root``,
        ``always_include`` and ``kwargs``.
    :param root_has_tools: Whether tools marked ``always_include`` should also be given to the root.
    :param kwargs: Extra keyword arguments forwarded to :class:`ReDelKani`.
    :returns: The fully set-up root kani.
    """
    root = ReDelKani(*args, app=app, dispatch_creation=False, **kwargs)

    delegator = delegation_scheme(app=app, kani=root) if delegation_scheme else None

    # A tool goes on the root if it is explicitly root-only, or if it is always-included and the root may have tools.
    root_tools = [
        tool_cls(app=app, kani=root, **config.get("kwargs", {}))
        for tool_cls, config in tool_configs.items()
        if config.get("always_include_root", False)
        or (config.get("always_include", False) and root_has_tools)
    ]

    # noinspection PyProtectedMember
    root._register_tools(delegator=delegator, tools=root_tools)
    if delegator:
        await delegator.setup()
    await asyncio.gather(*[tool.setup() for tool in root_tools])

    app.on_kani_creation(root)
    return root
def get_tool_functions(inst: ToolBase) -> dict[str, AIFunction]:
    """Wrap every bound method of *inst* that carries an ``__ai_function__`` attribute in an :class:`AIFunction`.

    The method's ``__ai_function__`` mapping is passed as keyword arguments to :class:`AIFunction`.

    :param inst: The object to scan (a tool or delegation scheme instance).
    :returns: A mapping from each AIFunction's name to the AIFunction.
    :raises ValueError: If two methods produce AIFunctions with the same name.
    """
    decorated = [
        member
        for _, member in inspect.getmembers(inst, predicate=inspect.ismethod)
        if hasattr(member, "__ai_function__")
    ]
    found: dict[str, AIFunction] = {}
    for method in decorated:
        ai_func = AIFunction(method, **method.__ai_function__)
        if ai_func.name in found:
            raise ValueError(f"AIFunction {ai_func.name!r} is already registered!")
        found[ai_func.name] = ai_func
    return found