diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cdbf56b95..04e5431985 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to Chainlit will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [2.8.5] - 2025-11-07 +### Added +- Add display_name to ChatProfile +- Add slack reaction event callback +- Add raw response from OAuth providers + +### Fixed +- Security vulnerability in Chainlint: added missed ACL check for session initialization + +### Changed +- Remove FastAPI version restrictions + ## [2.8.4] - 2025-10-29 ### Added diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index d4672a7e6a..437c33f4ca 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -1,6 +1,6 @@ import asyncio import json -from typing import Any, Dict, Literal, Optional, Tuple, Union +from typing import Any, Dict, Literal, Optional, Tuple, TypedDict, Union from urllib.parse import unquote from starlette.requests import cookie_parser @@ -18,7 +18,7 @@ from chainlit.logger import logger from chainlit.message import ErrorMessage, Message from chainlit.server import sio -from chainlit.session import WebsocketSession +from chainlit.session import ClientType, WebsocketSession from chainlit.types import ( InputAudioChunk, InputAudioChunkPayload, @@ -29,8 +29,13 @@ WSGIEnvironment: TypeAlias = dict[str, Any] -# Generic error message reused across resume flows. -THREAD_NOT_FOUND_MSG = "Thread not found." + +class WebSocketSessionAuth(TypedDict): + sessionId: str + userEnv: str | None + clientType: ClientType + chatProfile: str | None + threadId: str | None def restore_existing_session(sid, session_id, emit_fn, emit_call_fn): @@ -96,16 +101,15 @@ def _get_token_from_cookie(environ: WSGIEnvironment) -> Optional[str]: return None -def _get_token(environ: WSGIEnvironment, auth: dict) -> Optional[str]: +def _get_token(environ: WSGIEnvironment) -> Optional[str]: """Take WSGI environ, return access token.""" return _get_token_from_cookie(environ) async def _authenticate_connection( - environ, - auth, + environ: WSGIEnvironment, ) -> Union[Tuple[Union[User, PersistedUser], str], Tuple[None, None]]: - if token := _get_token(environ, auth): + if token := _get_token(environ): user = await get_current_user(token=token) if user: return user, token @@ -114,12 +118,14 @@ async def _authenticate_connection( @sio.on("connect") # pyright: ignore [reportOptionalCall] -async def connect(sid, environ, auth): - user = token = None +async def connect(sid: str, environ: WSGIEnvironment, auth: WebSocketSessionAuth): + user: User | PersistedUser | None = None + token: str | None = None + thread_id = auth.get("threadId") if require_login(): try: - user, token = await _authenticate_connection(environ, auth) + user, token = await _authenticate_connection(environ) except Exception as e: logger.exception("Exception authenticating connection: %s", e) @@ -127,6 +133,16 @@ async def connect(sid, environ, auth): logger.error("Authentication failed in websocket connect.") raise ConnectionRefusedError("authentication failed") + if thread_id: + data_layer = get_data_layer() + if not data_layer: + logger.error("Data layer is not initialized.") + raise ConnectionRefusedError("data layer not initialized") + + if not (await data_layer.get_thread_author(thread_id) == user.identifier): + logger.error("Authorization for the thread failed.") + raise ConnectionRefusedError("authorization failed") + # Session scoped function to emit to the client def emit_fn(event, data): return sio.emit(event, data, to=sid) @@ -135,14 +151,14 @@ def emit_fn(event, data): def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout): return sio.call(event, data, timeout=timeout, to=sid) - session_id = auth.get("sessionId") + session_id = auth["sessionId"] if restore_existing_session(sid, session_id, emit_fn, emit_call_fn): return True user_env_string = auth.get("userEnv") user_env = load_user_env(user_env_string) - client_type = auth.get("clientType") + client_type = auth["clientType"] url_encoded_chat_profile = auth.get("chatProfile") chat_profile = ( unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None @@ -158,7 +174,7 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout): user=user, token=token, chat_profile=chat_profile, - thread_id=auth.get("threadId"), + thread_id=thread_id, environ=environ, ) diff --git a/backend/chainlit/types.py b/backend/chainlit/types.py index e2e5552a69..f3c965d247 100644 --- a/backend/chainlit/types.py +++ b/backend/chainlit/types.py @@ -304,6 +304,7 @@ class ChatProfile(DataClassJsonMixin): name: str markdown_description: str icon: Optional[str] = None + display_name: Optional[str] = None default: bool = False starters: Optional[List[Starter]] = None config_overrides: Any = None diff --git a/backend/chainlit/version.py b/backend/chainlit/version.py index e897069b0b..79a53d3a76 100644 --- a/backend/chainlit/version.py +++ b/backend/chainlit/version.py @@ -1 +1 @@ -__version__ = "2.8.4" +__version__ = "2.8.5" diff --git a/cypress/e2e/chat_prefill/main.py b/cypress/e2e/chat_prefill/main.py new file mode 100644 index 0000000000..81460046fc --- /dev/null +++ b/cypress/e2e/chat_prefill/main.py @@ -0,0 +1,6 @@ +import chainlit as cl + + +@cl.on_chat_start +async def main(): + await cl.Message("Hello, this is a test message!").send() diff --git a/cypress/e2e/chat_prefill/spec.cy.ts b/cypress/e2e/chat_prefill/spec.cy.ts new file mode 100644 index 0000000000..406c979062 --- /dev/null +++ b/cypress/e2e/chat_prefill/spec.cy.ts @@ -0,0 +1,34 @@ +describe('Chat Prefill', () => { + it('should display a prefill message when the chat starts', () => { + cy.visit('/?prompt=Hello%20World'); + + cy.get('#chat-input', { timeout: 10000 }) + .should('be.visible') + .and('have.value', 'Hello World'); + }); + + it('should not prefill the chat when prompt is empty', () => { + cy.visit('/'); + + cy.get('#chat-input', { timeout: 10000 }) + .should('be.visible') + .and('have.value', ''); + }); + + it('should correctly prefill with special characters', () => { + const prompt = encodeURIComponent("Hi there! How's it going?"); + cy.visit(`/?prompt=${prompt}`); + + cy.get('#chat-input', { timeout: 10000 }) + .should('be.visible') + .and('have.value', "Hi there! How's it going?"); + }); + + it('should focus the chat input when prefilled', () => { + cy.visit('/?prompt=FocusTest'); + + cy.get('#chat-input', { timeout: 10000 }) + .should('be.visible') + .and('have.focus'); + }); +}); diff --git a/cypress/e2e/thread_resume/main.py b/cypress/e2e/thread_resume/main.py index 7bc2df8513..2f6ab41137 100644 --- a/cypress/e2e/thread_resume/main.py +++ b/cypress/e2e/thread_resume/main.py @@ -3,6 +3,8 @@ import chainlit as cl import chainlit.data as cl_data +from chainlit.element import ElementDict, Element +from chainlit.step import StepDict from chainlit.types import ( ThreadDict, Pagination, @@ -30,6 +32,62 @@ async def create_user(self, user: cl.User): id=user.identifier, createdAt=now, identifier=user.identifier ) + async def delete_feedback( + self, + feedback_id: str, + ) -> bool: + pass + + async def upsert_feedback( + self, + feedback: Feedback, + ) -> str: + pass + + async def create_element(self, element: "Element"): + pass + + async def get_element( + self, thread_id: str, element_id: str + ) -> Optional["ElementDict"]: + pass + + async def delete_element(self, element_id: str, thread_id: Optional[str] = None): + pass + + async def create_step(self, step_dict: "StepDict"): + pass + + async def update_step(self, step_dict: "StepDict"): + pass + + async def delete_step(self, step_id: str): + pass + + async def get_thread_author(self, thread_id: str) -> str: + return (await self.get_thread(thread_id))["userIdentifier"] + + async def delete_thread(self, thread_id: str): + for uid, threads in THREADS.items(): + THREADS[uid] = [t for t in threads if t["id"] != thread_id] + + async def list_threads( + self, pagination: Pagination, filters: ThreadFilter + ) -> PaginatedResponse[ThreadDict]: + user_id = filters.userId or "" + data = THREADS.get(user_id, []) + return PaginatedResponse( + data=data, + pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None), + ) + + async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": + for threads in THREADS.values(): + for t in threads: + if t["id"] == thread_id: + return t + return None + async def update_thread( self, thread_id: str, @@ -57,55 +115,7 @@ async def update_thread( if tags is not None: thr["tags"] = tags - async def list_threads( - self, pagination: Pagination, filters: ThreadFilter - ) -> PaginatedResponse[ThreadDict]: - user_id = filters.userId or "" - data = THREADS.get(user_id, []) - return PaginatedResponse( - data=data, - pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None), - ) - - async def get_thread(self, thread_id: str): - for threads in THREADS.values(): - for t in threads: - if t["id"] == thread_id: - return t - return None - - async def delete_thread(self, thread_id: str): - for uid, threads in THREADS.items(): - THREADS[uid] = [t for t in threads if t["id"] != thread_id] - - async def upsert_feedback(self, feedback: Feedback) -> str: - return "" - - async def build_debug_url(self): - pass - - async def create_element(self): - pass - - async def create_step(self): - pass - - async def delete_element(self): - pass - - async def delete_feedback(self): - pass - - async def delete_step(self): - pass - - async def get_element(self): - pass - - async def get_thread_author(self): - pass - - async def update_step(self): + async def build_debug_url(self) -> str: pass async def close(self) -> None: diff --git a/frontend/src/components/Tasklist/Task.tsx b/frontend/src/components/Tasklist/Task.tsx index 00f5804008..b82bb2bb3d 100644 --- a/frontend/src/components/Tasklist/Task.tsx +++ b/frontend/src/components/Tasklist/Task.tsx @@ -1,3 +1,5 @@ +import { Markdown } from '@/components/Markdown'; + import { TaskStatusIcon } from './TaskStatusIcon'; export interface ITask { @@ -14,9 +16,11 @@ export interface ITaskList { interface TaskProps { index: number; task: ITask; + allowHtml?: boolean; + latex?: boolean; } -export const Task = ({ index, task }: TaskProps) => { +export const Task = ({ index, task, allowHtml, latex }: TaskProps) => { const statusStyles = { ready: '', running: 'font-semibold', @@ -48,14 +52,26 @@ export const Task = ({ index, task }: TaskProps) => { return (