Skip to content
19 changes: 8 additions & 11 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def _authenticate_connection(
async def connect(sid: str, environ: WSGIEnvironment, auth: WebSocketSessionAuth):
user: User | PersistedUser | None = None
token: str | None = None
thread_id = auth.get("threadId")
thread_id = auth.get("threadId", None)

if require_login():
try:
Expand All @@ -134,14 +134,11 @@ async def connect(sid: str, environ: WSGIEnvironment, auth: WebSocketSessionAuth
raise ConnectionRefusedError("authentication failed")

if thread_id:
data_layer = get_data_layer()
if not data_layer:
logger.error("Data layer is not initialized.")
raise ConnectionRefusedError("data layer not initialized")

if not (await data_layer.get_thread_author(thread_id) == user.identifier):
logger.error("Authorization for the thread failed.")
raise ConnectionRefusedError("authorization failed")
if data_layer := get_data_layer():
thread = await data_layer.get_thread(thread_id)
if thread and not (thread["userIdentifier"] == user.identifier):
logger.error("Authorization for the thread failed.")
raise ConnectionRefusedError("authorization failed")

# Session scoped function to emit to the client
def emit_fn(event, data):
Expand All @@ -155,11 +152,11 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
if restore_existing_session(sid, session_id, emit_fn, emit_call_fn):
return True

user_env_string = auth.get("userEnv")
user_env_string = auth.get("userEnv", None)
user_env = load_user_env(user_env_string)

client_type = auth["clientType"]
url_encoded_chat_profile = auth.get("chatProfile")
url_encoded_chat_profile = auth.get("chatProfile", None)
chat_profile = (
unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
)
Expand Down
52 changes: 52 additions & 0 deletions cypress/e2e/auth/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
from uuid import uuid4

import chainlit as cl
from chainlit.auth import create_jwt
from chainlit.server import _authenticate_user, app
from chainlit.user import User
from fastapi import Request, Response

os.environ["CHAINLIT_AUTH_SECRET"] = "SUPER_SECRET" # nosec B105
os.environ["CHAINLIT_CUSTOM_AUTH"] = "true"


@app.get("/auth/custom")
async def custom_auth(request: Request) -> Response:
user_id = str(uuid4())

user = User(identifier=user_id, metadata={"role": "user"})
response = await _authenticate_user(request, user)

return response


@app.get("/auth/token")
async def custom_token_auth() -> Response:
user_id = str(uuid4())

user = User(identifier=user_id, metadata={"role": "admin"})
response = create_jwt(user)

return response


catch_all_route = None
for route in app.routes:
if route.path == "/{full_path:path}":
catch_all_route = route

if catch_all_route:
app.routes.remove(catch_all_route)
app.routes.append(catch_all_route)


@cl.on_chat_start
async def on_chat_start():
user = cl.user_session.get("user")
await cl.Message(f"Hello {user.identifier}").send()


@cl.on_message
async def on_message(msg: cl.Message):
await cl.Message(content=f"Echo: {msg.content}").send()
130 changes: 130 additions & 0 deletions cypress/e2e/auth/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import { loadCopilotScript, mountCopilotWidget, openCopilot, submitMessage } from '../../support/testUtils';

function login() {
return cy.request({
method: 'GET',
url: '/auth/custom',
followRedirect: false
})
}

function getToken() {
return cy.request({
method: 'GET',
url: '/auth/token',
followRedirect: false
})
}

function shouldShowGreetingMessage() {
it('should show greeting message', () => {
cy.get('.step').should('exist');
cy.get('.step').should('contain', 'Hello');
});
}

function shouldSendMessageAndRecieveAnswer() {
it('should send message and receive answer', () => {
cy.get('.step').should('contain', 'Hello');

const testMessage = 'Test message from custom auth';
submitMessage(testMessage);

cy.get('.step').should('contain', 'Echo:');
cy.get('.step').should('contain', testMessage);
});

}

describe('Custom Auth', () => {
describe('when unauthenticated', () => {
beforeEach(() => {
cy.intercept('GET', '/user').as('user');
});

it('should attempt to and not have permission to access /user', () => {
cy.wait('@user').then((interception) => {
expect(interception.response.statusCode).to.equal(401);
});
});

it('should redirect to login dialog', () => {
cy.location('pathname').should('eq', '/login');
});
});

describe('authenticating via custom endpoint', () => {
beforeEach(() => {
login().then((response) => {
expect(response.status).to.equal(200);
// Verify cookie is set in response headers
expect(response.headers).to.have.property('set-cookie');
const cookies = Array.isArray(response.headers['set-cookie'])
? response.headers['set-cookie']
: [response.headers['set-cookie']];
expect(cookies[0]).to.contain('access_token');
});
});

const shouldBeLoggedIn = () => {
it('should not be on /login', () => {
cy.location('pathname').should('not.contain', '/login');
});

shouldShowGreetingMessage();
};

shouldBeLoggedIn();

it('should request and have access to /user', () => {
cy.intercept('GET', '/user').as('user');
cy.wait('@user').then((interception) => {
expect(interception.response.statusCode).to.equal(200);
});
});

shouldSendMessageAndRecieveAnswer();

describe('after reloading', () => {
beforeEach(() => {
cy.reload();
});

shouldBeLoggedIn();
});
});
});

describe('Copilot Token', { includeShadowDom: true }, () => {
beforeEach(() => {
cy.location('pathname').should('eq', '/login');

loadCopilotScript();
});

describe('when unauthenticated', () => {
it('should throw error about missing authentication token', () => {
mountCopilotWidget();
openCopilot();
cy.get('#chainlit-copilot-chat').should('contain', 'No authentication token provided.');
});
});

describe('authenticating via custom endpoint', () => {
beforeEach(() => {
getToken().then((response) => {
expect(response.status).to.equal(200);

const accessToken = response.body
expect(accessToken).to.not.be.null;

mountCopilotWidget({ accessToken });
openCopilot();
});
})

shouldShowGreetingMessage();

shouldSendMessageAndRecieveAnswer();
});
});
Loading
Loading