Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@ All notable changes to Chainlit will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [2.8.5] - 2025-11-07
### Added
- Add display_name to ChatProfile
- Add slack reaction event callback
- Add raw response from OAuth providers

### Fixed
- Security vulnerability in Chainlint: added missed ACL check for session initialization

### Changed
- Remove FastAPI version restrictions

## [2.8.4] - 2025-10-29

### Added
Expand Down
44 changes: 30 additions & 14 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import json
from typing import Any, Dict, Literal, Optional, Tuple, Union
from typing import Any, Dict, Literal, Optional, Tuple, TypedDict, Union
from urllib.parse import unquote

from starlette.requests import cookie_parser
Expand All @@ -18,7 +18,7 @@
from chainlit.logger import logger
from chainlit.message import ErrorMessage, Message
from chainlit.server import sio
from chainlit.session import WebsocketSession
from chainlit.session import ClientType, WebsocketSession
from chainlit.types import (
InputAudioChunk,
InputAudioChunkPayload,
Expand All @@ -29,8 +29,13 @@

WSGIEnvironment: TypeAlias = dict[str, Any]

# Generic error message reused across resume flows.
THREAD_NOT_FOUND_MSG = "Thread not found."

class WebSocketSessionAuth(TypedDict):
sessionId: str
userEnv: str | None
clientType: ClientType
chatProfile: str | None
threadId: str | None


def restore_existing_session(sid, session_id, emit_fn, emit_call_fn):
Expand Down Expand Up @@ -96,16 +101,15 @@ def _get_token_from_cookie(environ: WSGIEnvironment) -> Optional[str]:
return None


def _get_token(environ: WSGIEnvironment, auth: dict) -> Optional[str]:
def _get_token(environ: WSGIEnvironment) -> Optional[str]:
"""Take WSGI environ, return access token."""
return _get_token_from_cookie(environ)


async def _authenticate_connection(
environ,
auth,
environ: WSGIEnvironment,
) -> Union[Tuple[Union[User, PersistedUser], str], Tuple[None, None]]:
if token := _get_token(environ, auth):
if token := _get_token(environ):
user = await get_current_user(token=token)
if user:
return user, token
Expand All @@ -114,19 +118,31 @@ async def _authenticate_connection(


@sio.on("connect") # pyright: ignore [reportOptionalCall]
async def connect(sid, environ, auth):
user = token = None
async def connect(sid: str, environ: WSGIEnvironment, auth: WebSocketSessionAuth):
user: User | PersistedUser | None = None
token: str | None = None
thread_id = auth.get("threadId")

if require_login():
try:
user, token = await _authenticate_connection(environ, auth)
user, token = await _authenticate_connection(environ)
except Exception as e:
logger.exception("Exception authenticating connection: %s", e)

if not user:
logger.error("Authentication failed in websocket connect.")
raise ConnectionRefusedError("authentication failed")

if thread_id:
data_layer = get_data_layer()
if not data_layer:
logger.error("Data layer is not initialized.")
raise ConnectionRefusedError("data layer not initialized")

if not (await data_layer.get_thread_author(thread_id) == user.identifier):
logger.error("Authorization for the thread failed.")
raise ConnectionRefusedError("authorization failed")

# Session scoped function to emit to the client
def emit_fn(event, data):
return sio.emit(event, data, to=sid)
Expand All @@ -135,14 +151,14 @@ def emit_fn(event, data):
def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
return sio.call(event, data, timeout=timeout, to=sid)

session_id = auth.get("sessionId")
session_id = auth["sessionId"]
if restore_existing_session(sid, session_id, emit_fn, emit_call_fn):
return True

user_env_string = auth.get("userEnv")
user_env = load_user_env(user_env_string)

client_type = auth.get("clientType")
client_type = auth["clientType"]
url_encoded_chat_profile = auth.get("chatProfile")
chat_profile = (
unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
Expand All @@ -158,7 +174,7 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
user=user,
token=token,
chat_profile=chat_profile,
thread_id=auth.get("threadId"),
thread_id=thread_id,
environ=environ,
)

Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ class ChatProfile(DataClassJsonMixin):
name: str
markdown_description: str
icon: Optional[str] = None
display_name: Optional[str] = None
default: bool = False
starters: Optional[List[Starter]] = None
config_overrides: Any = None
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.8.4"
__version__ = "2.8.5"
6 changes: 6 additions & 0 deletions cypress/e2e/chat_prefill/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import chainlit as cl


@cl.on_chat_start
async def main():
await cl.Message("Hello, this is a test message!").send()
34 changes: 34 additions & 0 deletions cypress/e2e/chat_prefill/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
describe('Chat Prefill', () => {
it('should display a prefill message when the chat starts', () => {
cy.visit('/?prompt=Hello%20World');

cy.get('#chat-input', { timeout: 10000 })
.should('be.visible')
.and('have.value', 'Hello World');
});

it('should not prefill the chat when prompt is empty', () => {
cy.visit('/');

cy.get('#chat-input', { timeout: 10000 })
.should('be.visible')
.and('have.value', '');
});

it('should correctly prefill with special characters', () => {
const prompt = encodeURIComponent("Hi there! How's it going?");
cy.visit(`/?prompt=${prompt}`);

cy.get('#chat-input', { timeout: 10000 })
.should('be.visible')
.and('have.value', "Hi there! How's it going?");
});

it('should focus the chat input when prefilled', () => {
cy.visit('/?prompt=FocusTest');

cy.get('#chat-input', { timeout: 10000 })
.should('be.visible')
.and('have.focus');
});
});
108 changes: 59 additions & 49 deletions cypress/e2e/thread_resume/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import chainlit as cl
import chainlit.data as cl_data
from chainlit.element import ElementDict, Element
from chainlit.step import StepDict
from chainlit.types import (
ThreadDict,
Pagination,
Expand Down Expand Up @@ -30,6 +32,62 @@ async def create_user(self, user: cl.User):
id=user.identifier, createdAt=now, identifier=user.identifier
)

async def delete_feedback(
self,
feedback_id: str,
) -> bool:
pass

async def upsert_feedback(
self,
feedback: Feedback,
) -> str:
pass

async def create_element(self, element: "Element"):
pass

async def get_element(
self, thread_id: str, element_id: str
) -> Optional["ElementDict"]:
pass

async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
pass

async def create_step(self, step_dict: "StepDict"):
pass

async def update_step(self, step_dict: "StepDict"):
pass

async def delete_step(self, step_id: str):
pass

async def get_thread_author(self, thread_id: str) -> str:
return (await self.get_thread(thread_id))["userIdentifier"]

async def delete_thread(self, thread_id: str):
for uid, threads in THREADS.items():
THREADS[uid] = [t for t in threads if t["id"] != thread_id]

async def list_threads(
self, pagination: Pagination, filters: ThreadFilter
) -> PaginatedResponse[ThreadDict]:
user_id = filters.userId or ""
data = THREADS.get(user_id, [])
return PaginatedResponse(
data=data,
pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None),
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
for threads in THREADS.values():
for t in threads:
if t["id"] == thread_id:
return t
return None

async def update_thread(
self,
thread_id: str,
Expand Down Expand Up @@ -57,55 +115,7 @@ async def update_thread(
if tags is not None:
thr["tags"] = tags

async def list_threads(
self, pagination: Pagination, filters: ThreadFilter
) -> PaginatedResponse[ThreadDict]:
user_id = filters.userId or ""
data = THREADS.get(user_id, [])
return PaginatedResponse(
data=data,
pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None),
)

async def get_thread(self, thread_id: str):
for threads in THREADS.values():
for t in threads:
if t["id"] == thread_id:
return t
return None

async def delete_thread(self, thread_id: str):
for uid, threads in THREADS.items():
THREADS[uid] = [t for t in threads if t["id"] != thread_id]

async def upsert_feedback(self, feedback: Feedback) -> str:
return ""

async def build_debug_url(self):
pass

async def create_element(self):
pass

async def create_step(self):
pass

async def delete_element(self):
pass

async def delete_feedback(self):
pass

async def delete_step(self):
pass

async def get_element(self):
pass

async def get_thread_author(self):
pass

async def update_step(self):
async def build_debug_url(self) -> str:
pass

async def close(self) -> None:
Expand Down
26 changes: 21 additions & 5 deletions frontend/src/components/Tasklist/Task.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { Markdown } from '@/components/Markdown';

import { TaskStatusIcon } from './TaskStatusIcon';

export interface ITask {
Expand All @@ -14,9 +16,11 @@ export interface ITaskList {
interface TaskProps {
index: number;
task: ITask;
allowHtml?: boolean;
latex?: boolean;
}

export const Task = ({ index, task }: TaskProps) => {
export const Task = ({ index, task, allowHtml, latex }: TaskProps) => {
const statusStyles = {
ready: '',
running: 'font-semibold',
Expand Down Expand Up @@ -48,14 +52,26 @@ export const Task = ({ index, task }: TaskProps) => {
return (
<div className={`task task-status-${task.status}`}>
<div
className={`w-full flex font-medium py-2 text-sm leading-snug ${
className={`w-full grid grid-cols-[auto_auto_1fr] items-start gap-1.5 font-medium py-0.5 px-1 text-sm leading-tight ${
statusStyles[task.status]
} ${task.forId ? 'cursor-pointer' : 'cursor-default'}`}
onClick={handleClick}
>
<span className="flex-none w-8 pr-2">{index}</span>
<TaskStatusIcon status={task.status} />
<span className="pl-2">{task.title}</span>
<div className="text-xs text-muted-foreground text-right pr-1 pt-[1px]">
{index}
</div>
<div className="flex items-start pt-[1px]">
<TaskStatusIcon status={task.status} />
</div>
<div className="min-w-0">
<Markdown
allowHtml={allowHtml}
latex={latex}
className="max-w-none prose-sm text-left break-words [&_p]:m-0 [&_p]:leading-snug [&_div]:leading-snug [&_div]:mt-0 [&_strong]:font-semibold"
>
{task.title}
</Markdown>
</div>
</div>
</div>
);
Expand Down
Loading
Loading