Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ private Mono<ChatMessages> internalChatMessageContentsAsync(

ChatCompletionsOptions options = executeHook(
invocationContext,
kernel,
new PreChatCompletionEvent(
getCompletionsOptions(
this,
Expand Down Expand Up @@ -349,7 +350,7 @@ private Mono<ChatMessages> internalChatMessageContentsAsync(
.collect(Collectors.toList());

// execute post chat completion hook
executeHook(invocationContext, new PostChatCompletionEvent(completions));
executeHook(invocationContext, kernel, new PostChatCompletionEvent(completions));

// Just return the result:
// If we don't want to attempt to invoke any functions
Expand Down Expand Up @@ -517,11 +518,12 @@ private Mono<FunctionResult<String>> invokeFunctionTool(
pluginName,
openAIFunctionToolCall.getFunctionName());

PreToolCallEvent hookResult = executeHook(invocationContext, new PreToolCallEvent(
openAIFunctionToolCall.getFunctionName(),
openAIFunctionToolCall.getArguments(),
function,
contextVariableTypes));
PreToolCallEvent hookResult = executeHook(invocationContext, kernel,
new PreToolCallEvent(
openAIFunctionToolCall.getFunctionName(),
openAIFunctionToolCall.getArguments(),
function,
contextVariableTypes));

function = hookResult.getFunction();
KernelFunctionArguments arguments = hookResult.getArguments();
Expand All @@ -537,12 +539,21 @@ private Mono<FunctionResult<String>> invokeFunctionTool(

private static <T extends KernelHookEvent> T executeHook(
@Nullable InvocationContext invocationContext,
@Nullable Kernel kernel,
T event) {
KernelHooks kernelHooks = invocationContext != null
&& invocationContext.getKernelHooks() != null
? invocationContext.getKernelHooks()
: new KernelHooks();

KernelHooks kernelHooks = null;
if (kernel == null) {
if (invocationContext != null) {
kernelHooks = invocationContext.getKernelHooks();
}
} else {
kernelHooks = KernelHooks.merge(
kernel.getGlobalKernelHooks(),
invocationContext != null ? invocationContext.getKernelHooks() : null);
}
if (kernelHooks == null) {
return event;
}
return kernelHooks.executeHooks(event);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypes;
import com.microsoft.semantickernel.hooks.KernelHooks;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.InvocationContext.Builder;
import com.microsoft.semantickernel.orchestration.InvocationReturnMode;
Expand Down Expand Up @@ -73,6 +74,27 @@ public static void main(String[] args) throws Exception {
.toPromptString(new Gson()::toJson)
.build());

KernelHooks hook = new KernelHooks();

hook.addPreToolCallHook((context) -> {
System.out.println("Pre-tool call hook");
return context;
});

hook.addPreChatCompletionHook(
(context) -> {
System.out.println("Pre-chat completion hook");
return context;
});

hook.addPostChatCompletionHook(
(context) -> {
System.out.println("Post-chat completion hook");
return context;
});

kernel.getGlobalKernelHooks().addHooks(hook);

// Enable planning
InvocationContext invocationContext = new Builder()
.withReturnMode(InvocationReturnMode.LAST_MESSAGE_ONLY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ public static void main(String[] args) {
inMemoryDataStorage(embeddingGeneration);
}

public static void inMemoryDataStorage(OpenAITextEmbeddingGenerationService embeddingGeneration) {
public static void inMemoryDataStorage(
OpenAITextEmbeddingGenerationService embeddingGeneration) {
// Create a new Volatile vector store
var volatileVectorStore = new VolatileVectorStore();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public UnmodifiableKernelHooks unmodifiableClone() {
*
* @return an unmodifiable map of the hooks
*/
private Map<String, KernelHook<?>> getHooks() {
protected Map<String, KernelHook<?>> getHooks() {
return Collections.unmodifiableMap(hooks);
}

Expand Down Expand Up @@ -224,6 +224,31 @@ public boolean isEmpty() {
return hooks.isEmpty();
}

/**
* Builds the list of hooks to be invoked for the given context, by merging the hooks in this
* collection with the hooks in the context. Duplicate hooks in b will override hooks in a.
*
* @param a hooks to merge
* @param b hooks to merge
* @return the list of hooks to be invoked
*/
public static KernelHooks merge(@Nullable KernelHooks a, @Nullable KernelHooks b) {
KernelHooks hooks = a;
if (hooks == null) {
hooks = new KernelHooks();
}

if (b == null) {
return hooks;
} else if (hooks.isEmpty()) {
return b;
} else {
HashMap<String, KernelHook<?>> merged = new HashMap<>(hooks.getHooks());
merged.putAll(b.getHooks());
return new KernelHooks(merged);
}
}

/**
* A wrapper for KernelHooks that disables mutating methods.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,9 @@ public static <T> ImplementationFunc<T> getFunction(Method method, Object instan
}

// kernelHooks must be effectively final for lambda
KernelHooks kernelHooks = context.getKernelHooks() != null
? context.getKernelHooks()
: kernel.getGlobalKernelHooks();
assert kernelHooks != null : "getGlobalKernelHooks() should never return null!";
KernelHooks kernelHooks = KernelHooks.merge(
kernel.getGlobalKernelHooks(),
context.getKernelHooks());

FunctionInvokingEvent updatedState = kernelHooks
.executeHooks(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ private Flux<FunctionResult<T>> invokeInternalAsync(
: InvocationContext.builder().build();

// must be effectively final for lambda
KernelHooks kernelHooks = context.getKernelHooks() != null
? context.getKernelHooks()
: kernel.getGlobalKernelHooks();
assert kernelHooks != null : "getGlobalKernelHooks() should never return null";
KernelHooks kernelHooks = KernelHooks.merge(
kernel.getGlobalKernelHooks(),
context.getKernelHooks());

PromptRenderingEvent preRenderingHookResult = kernelHooks
.executeHooks(new PromptRenderingEvent(this, argumentsIn));
Expand Down