diff --git a/.gitignore b/.gitignore index 9b0ebaef6b5..12b1ce7f9fa 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ __debug_bin* demo/output/* coverage.out +cmd/copilot/main diff --git a/cmd/copilot/main.go b/cmd/copilot/main.go new file mode 100644 index 00000000000..1f3b577b31a --- /dev/null +++ b/cmd/copilot/main.go @@ -0,0 +1,499 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/sanity-io/litter" + "github.com/zalando/go-keyring" +) + +const ( + KEYRING_SERVICE = "lazygit" + KEYRING_USER = "github-copilot" +) + +// ICopilotChat defines the interface for chat operations +type ICopilotChat interface { + Authenticate() error + IsAuthenticated() bool + Chat(request Request) (string, error) +} + +var _ ICopilotChat = &CopilotChat{} + +type Role string + +const ( + RoleUser Role = "user" + RoleAssistant Role = "assistant" + RoleSystem Role = "system" +) + +type Model string + +const ( + Gpt4o Model = "gpt-4o-2024-05-13" + Gpt4 Model = "gpt-4" + Gpt3_5Turbo Model = "gpt-3.5-turbo" + O1Preview Model = "o1-preview-2024-09-12" + O1Mini Model = "o1-mini-2024-09-12" + Claude3_5Sonnet Model = "claude-3.5-sonnet" +) + +const ( + COPILOT_CHAT_COMPLETION_URL = "https://api.githubcopilot.com/chat/completions" + COPILOT_CHAT_AUTH_URL = "https://api.github.com/copilot_internal/v2/token" + EDITOR_VERSION = "Lazygit/0.44.0" + COPILOT_INTEGRATION_ID = "vscode-chat" +) +const ( + CACHE_FILE_NAME = ".copilot_auth.json" +) +const ( + CHECK_INTERVAL = 30 * time.Second + MAX_AUTH_TIME = 5 * time.Minute +) +const ( + GITHUB_CLIENT_ID = "Iv1.b507a08c87ecfe98" +) + +type ChatMessage struct { + Role Role `json:"role"` + Content string `json:"content"` +} + +type Request struct { + Intent bool `json:"intent"` + N int `json:"n"` + Stream bool `json:"stream"` + Temperature float32 `json:"temperature"` + Model Model `json:"model"` + Messages []ChatMessage `json:"messages"` +} + +type ContentFilterResult struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity"` +} + +type ContentFilterResults struct { + Hate ContentFilterResult `json:"hate"` + SelfHarm ContentFilterResult `json:"self_harm"` + Sexual ContentFilterResult `json:"sexual"` + Violence ContentFilterResult `json:"violence"` +} + +type ChatResponse struct { + Choices []ResponseChoice `json:"choices"` + Created int64 `json:"created"` + ID string `json:"id"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results"` + Usage Usage `json:"usage"` +} + +type ResponseChoice struct { + ContentFilterResults ContentFilterResults `json:"content_filter_results"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message ChatMessage `json:"message"` +} + +type PromptFilterResult struct { + ContentFilterResults ContentFilterResults `json:"content_filter_results"` + PromptIndex int `json:"prompt_index"` +} + +type Usage struct { + CompletionTokens int `json:"completion_tokens"` + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type ApiTokenResponse struct { + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` +} + +type ApiToken struct { + ApiKey string + ExpiresAt time.Time +} +type CacheData struct { + OAuthToken string `json:"oauth_token"` + ApiKey string `json:"api_key"` + ExpiresAt time.Time `json:"expires_at"` +} +type DeviceCodeResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationUri string `json:"verification_uri"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +type DeviceTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + Error string `json:"error,omitempty"` +} + +type CopilotChat struct { + OAuthToken string + ApiToken *ApiToken + Client *http.Client + mu sync.Mutex +} + +// TODO: import a library to count the number of tokens in a string +func (m Model) MaxTokenCount() int { + switch m { + case Gpt4o: + return 64000 + case Gpt4: + return 32768 + case Gpt3_5Turbo: + return 12288 + case O1Mini: + return 20000 + case O1Preview: + return 20000 + case Claude3_5Sonnet: + return 200000 + default: + return 0 + } +} + +func NewCopilotChat(client *http.Client) *CopilotChat { + if client == nil { + client = &http.Client{} + } + + chat := &CopilotChat{ + Client: client, + } + + if err := chat.loadFromKeyring(); err != nil { + log.Printf("Warning: Failed to load from keyring: %v", err) + } + + return chat +} + +func (self *CopilotChat) saveToKeyring() error { + data := CacheData{ + OAuthToken: self.OAuthToken, + ApiKey: self.ApiToken.ApiKey, + ExpiresAt: self.ApiToken.ExpiresAt, + } + + fileData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal keyring data: %v", err) + } + + if err := keyring.Set(KEYRING_SERVICE, KEYRING_USER, string(fileData)); err != nil { + return fmt.Errorf("failed to save to keyring: %v", err) + } + + return nil +} + +func (self *CopilotChat) loadFromKeyring() error { + jsonData, err := keyring.Get(KEYRING_SERVICE, KEYRING_USER) + if err != nil { + if err == keyring.ErrNotFound { + return nil // No credentials stored yet + } + return fmt.Errorf("failed to get credentials from keyring: %v", err) + } + + var data CacheData + if err := json.Unmarshal([]byte(jsonData), &data); err != nil { + return fmt.Errorf("failed to unmarshal Keyring data: %v", err) + } + + // Always load OAuth token if it exists + if data.OAuthToken != "" { + self.OAuthToken = data.OAuthToken + } + + // If we have a valid API key, use it + if data.ApiKey != "" && data.ExpiresAt.After(time.Now()) { + self.ApiToken = &ApiToken{ + ApiKey: data.ApiKey, + ExpiresAt: data.ExpiresAt, + } + fmt.Println("Loaded valid API key from keyring") + return nil + } + + // If we have OAuth token but no valid API key, fetch a new one + if self.OAuthToken != "" { + fmt.Println("OAuth token found, fetching new API key...") + if err := self.fetchNewApiToken(); err != nil { + return fmt.Errorf("failed to fetch new API token: %v", err) + } + fmt.Println("Successfully fetched new API key") + return nil + } + + return nil +} + +func (self *CopilotChat) fetchNewApiToken() error { + apiTokenReq, err := http.NewRequest(http.MethodGet, COPILOT_CHAT_AUTH_URL, nil) + if err != nil { + return fmt.Errorf("failed to create API token request: %v", err) + } + + apiTokenReq.Header.Set("Authorization", fmt.Sprintf("token %s", self.OAuthToken)) + setHeaders(apiTokenReq, "") + + apiTokenResp, err := self.Client.Do(apiTokenReq) + if err != nil { + return fmt.Errorf("failed to get API token: %v", err) + } + defer apiTokenResp.Body.Close() + + var apiTokenResponse ApiTokenResponse + if err := json.NewDecoder(apiTokenResp.Body).Decode(&apiTokenResponse); err != nil { + return fmt.Errorf("failed to decode API token response: %v", err) + } + + self.ApiToken = &ApiToken{ + ApiKey: apiTokenResponse.Token, + ExpiresAt: time.Unix(apiTokenResponse.ExpiresAt, 0), + } + + return self.saveToKeyring() +} + +func setHeaders(req *http.Request, contentType string) { + req.Header.Set("Accept", "application/json") + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + req.Header.Set("Editor-Version", EDITOR_VERSION) + req.Header.Set("Copilot-Integration-Id", COPILOT_INTEGRATION_ID) +} + +func (self *CopilotChat) Authenticate() error { + // Try to load from keyring first + if err := self.loadFromKeyring(); err == nil && self.IsAuthenticated() { + return nil + } + + self.mu.Lock() + defer self.mu.Unlock() + + // Step 1: Request device and user codes + deviceCodeReq, err := http.NewRequest( + http.MethodPost, + "https://github.com/login/device/code", + strings.NewReader(fmt.Sprintf( + "client_id=%s&scope=copilot", + GITHUB_CLIENT_ID, + )), + ) + if err != nil { + return fmt.Errorf("failed to create device code request: %v", err) + } + deviceCodeReq.Header.Set("Accept", "application/json") + deviceCodeReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := self.Client.Do(deviceCodeReq) + if err != nil { + return fmt.Errorf("failed to get device code: %v", err) + } + defer resp.Body.Close() + + var deviceCode DeviceCodeResponse + if err := json.NewDecoder(resp.Body).Decode(&deviceCode); err != nil { + return fmt.Errorf("failed to decode device code response: %v", err) + } + + // Step 2: Display user code and verification URL + fmt.Printf("\nPlease visit: %s\n", deviceCode.VerificationUri) + fmt.Printf("And enter code: %s\n\n", deviceCode.UserCode) + + // Step 3: Poll for the access token with timeout + startTime := time.Now() + attempts := 0 + + // FIXME: There is probably a better way to do this + for { + if time.Since(startTime) >= MAX_AUTH_TIME { + return fmt.Errorf("authentication timed out after 5 minutes") + } + + time.Sleep(CHECK_INTERVAL) + attempts++ + fmt.Printf("Checking for authentication... attempt %d\n", attempts) + + tokenReq, err := http.NewRequest(http.MethodPost, "https://github.com/login/oauth/access_token", + strings.NewReader(fmt.Sprintf( + "client_id=%s&device_code=%s&grant_type=urn:ietf:params:oauth:grant-type:device_code", GITHUB_CLIENT_ID, + deviceCode.DeviceCode))) + if err != nil { + return fmt.Errorf("failed to create token request: %v", err) + } + tokenReq.Header.Set("Accept", "application/json") + tokenReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + tokenResp, err := self.Client.Do(tokenReq) + if err != nil { + return fmt.Errorf("failed to get access token: %v", err) + } + + var tokenResponse DeviceTokenResponse + if err := json.NewDecoder(tokenResp.Body).Decode(&tokenResponse); err != nil { + tokenResp.Body.Close() + return fmt.Errorf("failed to decode token response: %v", err) + } + tokenResp.Body.Close() + + if tokenResponse.Error == "authorization_pending" { + fmt.Println("Login not detected. Please visit the URL and enter the code.") + continue + } + if tokenResponse.Error != "" { + if time.Since(startTime) >= MAX_AUTH_TIME { + return fmt.Errorf("authentication timed out after 5 minutes") + } + continue + } + + // Successfully got the access token + self.OAuthToken = tokenResponse.AccessToken + + // Now get the Copilot API token using fetchNewApiToken + if err := self.fetchNewApiToken(); err != nil { + return fmt.Errorf("failed to fetch API token: %v", err) + } + + fmt.Println("Successfully authenticated!") + // Save the new credentials to cache + if err := self.saveToKeyring(); err != nil { + log.Printf("Warning: Failed to save credentials to keyring: %v", err) + } + return nil + } +} + +func (self *CopilotChat) IsAuthenticated() bool { + if self.ApiToken == nil { + return false + } + return self.ApiToken.ExpiresAt.After(time.Now()) +} + +func (self *CopilotChat) Chat(request Request) (string, error) { + fmt.Println("Chatting with Copilot...") + + if !self.IsAuthenticated() { + fmt.Println("Not authenticated with Copilot. Authenticating...") + if err := self.Authenticate(); err != nil { + return "", fmt.Errorf("authentication failed: %v", err) + } + } + + apiKey := self.ApiToken.ApiKey + fmt.Println("Authenticated with Copilot!") + fmt.Println("API Key: ", apiKey) + + litter.Dump(self) + + requestBody, err := json.Marshal(request) + if err != nil { + return "", err + } + fmt.Println("Mounting request body: ", string(requestBody)) + + self.mu.Lock() + defer self.mu.Unlock() + + req, err := http.NewRequest(http.MethodPost, COPILOT_CHAT_COMPLETION_URL, strings.NewReader(string(requestBody))) + if err != nil { + return "", err + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + setHeaders(req, "") + + response, err := self.Client.Do(req) + if err != nil { + return "", err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(response.Body) + return "", fmt.Errorf("failed to get completion: %s", string(body)) + } + + var chatResponse ChatResponse + decoder := json.NewDecoder(response.Body) + if err := decoder.Decode(&chatResponse); err != nil { + return "", fmt.Errorf("failed to decode response: %v", err) + } + + if len(chatResponse.Choices) == 0 { + return "", fmt.Errorf("no choices in response") + } + + return chatResponse.Choices[0].Message.Content, nil +} + +func main() { + client := &http.Client{} + fmt.Println("Starting...") + copilotChat := NewCopilotChat(client) + + fmt.Println("Chatting...") + + err := copilotChat.Authenticate() + if err != nil { + if strings.Contains(err.Error(), "timed out") { + log.Fatalf("Authentication process timed out. Please try again later.") + } + log.Fatalf("Error during authentication: %v", err) + } + + fmt.Println("Authenticated!") + + messages := []ChatMessage{ + { + Role: RoleUser, + Content: "Describe what is Lazygit in one sentence", + }, + } + + request := Request{ + Intent: true, + N: 1, + Stream: false, + Temperature: 0.1, + Model: Gpt4o, + Messages: messages, + } + + response, err := copilotChat.Chat(request) + if err != nil { + log.Fatalf("Error during chat: %v", err) + } + + fmt.Println(response) +} diff --git a/pkg/config/user_config.go b/pkg/config/user_config.go index 0fa706bb242..ae77d7b9411 100644 --- a/pkg/config/user_config.go +++ b/pkg/config/user_config.go @@ -38,6 +38,8 @@ type UserConfig struct { PromptToReturnFromSubprocess bool `yaml:"promptToReturnFromSubprocess"` // Keybindings Keybinding KeybindingConfig `yaml:"keybinding"` + // AI-powered commit message generation + AI AIConfig `yaml:"ai"` } type RefresherConfig struct { @@ -575,6 +577,23 @@ type KeybindingCommitMessageConfig struct { CommitMenu string `yaml:"commitMenu"` } +type AIConfig struct { + // If true, AI-powered commit message generation is enabled + Enabled bool `yaml:"enabled"` + // AI provider to use: "openai", "github", "anthropic", etc. + Provider string `yaml:"provider"` + // Model name to use: "gpt-4", "copilot", etc. + Model string `yaml:"model"` + // API key for the AI provider + APIKey string `yaml:"apiKey"` + // Base URL for custom API endpoints + BaseURL string `yaml:"baseUrl"` + // Maximum number of tokens to generate + MaxTokens int `yaml:"maxTokens"` + // Temperature for response randomness (0.0 to 1.0) + Temperature float64 `yaml:"temperature"` +} + // OSConfig contains config on the level of the os type OSConfig struct { // Command for editing a file. Should contain "{{filename}}". @@ -1029,5 +1048,14 @@ func GetDefaultConfig() *UserConfig { CommitMenu: "", }, }, + AI: AIConfig{ + Enabled: false, + Provider: "openai", + Model: "gpt-4", + APIKey: "", + BaseURL: "https://api.openai.com/v1", + MaxTokens: 150, + Temperature: 0.7, + }, } } diff --git a/pkg/gui/services/ai/context_builder.go b/pkg/gui/services/ai/context_builder.go new file mode 100644 index 00000000000..8b38fb6b853 --- /dev/null +++ b/pkg/gui/services/ai/context_builder.go @@ -0,0 +1,162 @@ +package ai + +import ( + "path/filepath" + "strings" + + "github.com/jesseduffield/lazygit/pkg/gui/controllers/helpers" +) + +// ContextBuilder builds AI context from git repository state +type ContextBuilder struct { + c *helpers.HelperCommon +} + +// NewContextBuilder creates a new context builder +func NewContextBuilder(c *helpers.HelperCommon) *ContextBuilder { + return &ContextBuilder{c: c} +} + +// BuildContext builds a GenerateRequest from current git state +func (cb *ContextBuilder) BuildContext(commitType string, existingMessage string) (*GenerateRequest, error) { + // TODO: Implement context building + // 1. Get staged diff + // 2. Analyze file changes + // 3. Get branch name + // 4. Get recent commits for style reference + // 5. Detect project characteristics + + request := &GenerateRequest{ + CommitType: commitType, + ExistingMessage: existingMessage, + } + + // Get staged diff + if err := cb.setStagedDiff(request); err != nil { + return nil, err + } + + // Analyze file changes + if err := cb.setFileChanges(request); err != nil { + return nil, err + } + + // Set git context + cb.setGitContext(request) + + return request, nil +} + +// setStagedDiff gets the staged diff and sets it in the request +func (cb *ContextBuilder) setStagedDiff(request *GenerateRequest) error { + // TODO: Execute `git diff --staged --no-color` and set StagedDiff + // Handle case where there are no staged changes + request.StagedDiff = "" + return nil +} + +// setFileChanges analyzes the staged changes and extracts file information +func (cb *ContextBuilder) setFileChanges(request *GenerateRequest) error { + // TODO: Parse git diff to extract file changes + // 1. Get list of changed files with status + // 2. Detect programming languages + // 3. Count lines added/deleted + // 4. Detect binary files + + request.FileChanges = []FileChange{} + return nil +} + +// setGitContext sets git repository context information +func (cb *ContextBuilder) setGitContext(request *GenerateRequest) { + // TODO: Get current branch name + request.BranchName = cb.getBranchName() + + // TODO: Get project name from repository + request.ProjectName = cb.getProjectName() + + // TODO: Get recent commits for style analysis + request.RecentCommits = cb.getRecentCommits() +} + +// getBranchName returns the current git branch name +func (cb *ContextBuilder) getBranchName() string { + // TODO: Execute `git branch --show-current` or equivalent + return "" +} + +// getProjectName extracts the project name from repository path or remote +func (cb *ContextBuilder) getProjectName() string { + // TODO: Get project name from: + // 1. Repository directory name + // 2. Remote origin URL + // 3. package.json, go.mod, etc. + return "" +} + +// getRecentCommits gets recent commit messages for style analysis +func (cb *ContextBuilder) getRecentCommits() []string { + // TODO: Execute `git log --oneline -10` to get recent commits + // Filter out merge commits and format consistently + return []string{} +} + +// detectLanguage detects the programming language from file extension +func (cb *ContextBuilder) detectLanguage(filePath string) string { + ext := strings.ToLower(filepath.Ext(filePath)) + + languageMap := map[string]string{ + ".go": "Go", + ".js": "JavaScript", + ".ts": "TypeScript", + ".py": "Python", + ".java": "Java", + ".cpp": "C++", + ".c": "C", + ".rs": "Rust", + ".rb": "Ruby", + ".php": "PHP", + ".cs": "C#", + ".swift": "Swift", + ".kt": "Kotlin", + ".scala": "Scala", + ".sh": "Shell", + ".yml": "YAML", + ".yaml": "YAML", + ".json": "JSON", + ".xml": "XML", + ".md": "Markdown", + ".sql": "SQL", + ".dockerfile": "Docker", + } + + if lang, exists := languageMap[ext]; exists { + return lang + } + + // Check for special files + filename := strings.ToLower(filepath.Base(filePath)) + switch filename { + case "dockerfile": + return "Docker" + case "makefile": + return "Make" + case "rakefile": + return "Ruby" + default: + return "Text" + } +} + +// BuildPrompt builds the AI prompt from the request context +func (cb *ContextBuilder) BuildPrompt(request *GenerateRequest) string { + // TODO: Build a well-structured prompt for the AI + // Include: + // 1. Task description + // 2. Context about the changes + // 3. Style guidelines + // 4. Examples from recent commits + // 5. Specific requirements (conventional commits, etc.) + + return "" +} diff --git a/pkg/gui/services/ai/message_validator.go b/pkg/gui/services/ai/message_validator.go new file mode 100644 index 00000000000..078bfb3ac74 --- /dev/null +++ b/pkg/gui/services/ai/message_validator.go @@ -0,0 +1,191 @@ +package ai + +import ( + "strings" + "unicode/utf8" +) + +// MessageValidator validates AI-generated commit messages +type MessageValidator struct { + maxSubjectLength int + maxBodyLength int + requireSubject bool +} + +// NewMessageValidator creates a new message validator with default settings +func NewMessageValidator() *MessageValidator { + return &MessageValidator{ + maxSubjectLength: 72, // Standard git convention + maxBodyLength: 72, // Per line in body + requireSubject: true, + } +} + +// ValidateMessage validates a generated commit message +func (mv *MessageValidator) ValidateMessage(response *GenerateResponse) error { + if response == nil { + return ErrInvalidResponse + } + + // Check if message is empty + if strings.TrimSpace(response.Message) == "" { + return ErrInvalidResponse + } + + // Validate subject line + if err := mv.validateSubject(response.Message); err != nil { + return err + } + + // Validate message length + if err := mv.validateLength(response.Message); err != nil { + return err + } + + // Check for inappropriate content + if err := mv.validateContent(response.Message); err != nil { + return err + } + + return nil +} + +// validateSubject validates the commit message subject line +func (mv *MessageValidator) validateSubject(message string) error { + lines := strings.Split(message, "\n") + if len(lines) == 0 { + return ErrInvalidResponse + } + + subject := strings.TrimSpace(lines[0]) + + // Check if subject is required and present + if mv.requireSubject && subject == "" { + return ErrInvalidResponse + } + + // Check subject length + if utf8.RuneCountInString(subject) > mv.maxSubjectLength { + return ErrMessageTooLong + } + + // TODO: Add more subject validation rules: + // - No trailing period + // - Capitalized first letter + // - Imperative mood check + // - Conventional commit format validation + + return nil +} + +// validateLength validates the overall message length +func (mv *MessageValidator) validateLength(message string) error { + lines := strings.Split(message, "\n") + + // Check body line lengths (skip empty lines and subject) + for i, line := range lines { + if i == 0 || strings.TrimSpace(line) == "" { + continue // Skip subject line and empty lines + } + + if utf8.RuneCountInString(line) > mv.maxBodyLength { + return ErrMessageTooLong + } + } + + return nil +} + +// validateContent checks for inappropriate content in the message +func (mv *MessageValidator) validateContent(message string) error { + // TODO: Implement content filtering + // Check for: + // - Profanity or inappropriate language + // - Personal information (emails, API keys, etc.) + // - Nonsensical or irrelevant content + // - Obvious AI-generated artifacts + + // Basic checks for now + message = strings.ToLower(message) + + // Check for placeholder text that AI might generate + prohibitedPhrases := []string{ + "lorem ipsum", + "placeholder", + "todo:", + "fixme:", + "[your text here]", + "replace this", + } + + for _, phrase := range prohibitedPhrases { + if strings.Contains(message, phrase) { + return ErrInappropriateContent + } + } + + return nil +} + +// SanitizeMessage cleans up the generated message +func (mv *MessageValidator) SanitizeMessage(message string) string { + // TODO: Implement message sanitization + // - Remove extra whitespace + // - Fix capitalization + // - Remove trailing periods from subject + // - Ensure proper line breaks + + // Basic cleanup for now + lines := strings.Split(message, "\n") + var cleanLines []string + + for i, line := range lines { + line = strings.TrimSpace(line) + if i == 0 { + // Subject line: remove trailing period, capitalize first letter + line = strings.TrimSuffix(line, ".") + if len(line) > 0 { + line = strings.ToUpper(string(line[0])) + line[1:] + } + } + if line != "" || (i > 0 && i < len(lines)-1) { + cleanLines = append(cleanLines, line) + } + } + + return strings.Join(cleanLines, "\n") +} + +// IsConventionalCommit checks if the message follows conventional commit format +func (mv *MessageValidator) IsConventionalCommit(message string) bool { + // TODO: Implement conventional commit validation + // Check for format: type(scope): description + // Common types: feat, fix, docs, style, refactor, test, chore + + lines := strings.Split(message, "\n") + if len(lines) == 0 { + return false + } + + subject := strings.TrimSpace(lines[0]) + + // Basic pattern matching for conventional commits + conventionalTypes := []string{ + "feat:", "fix:", "docs:", "style:", "refactor:", + "test:", "chore:", "perf:", "ci:", "build:", + "revert:", "merge:", "release:", + } + + for _, ctype := range conventionalTypes { + if strings.HasPrefix(strings.ToLower(subject), ctype) { + return true + } + } + + // Check for scoped format: type(scope): + if strings.Contains(subject, "(") && strings.Contains(subject, "):") { + return true + } + + return false +} diff --git a/pkg/gui/services/ai/models.go b/pkg/gui/services/ai/models.go new file mode 100644 index 00000000000..cdca30fe934 --- /dev/null +++ b/pkg/gui/services/ai/models.go @@ -0,0 +1,64 @@ +package ai + +import "errors" + +// This file contains the data models and interfaces for the AI service. +// These types define the contract between the AI service components and +// external AI providers for generating commit messages. + +// Provider interface for different AI providers (OpenAI, GitHub Copilot, etc.) +type Provider interface { + GenerateMessage(prompt string) (string, error) + Name() string + ValidateConfig() error +} + +// GenerateRequest contains the context for generating a commit message +type GenerateRequest struct { + StagedDiff string `json:"staged_diff"` + FileChanges []FileChange `json:"file_changes"` + BranchName string `json:"branch_name"` + RecentCommits []string `json:"recent_commits"` + ProjectName string `json:"project_name"` + CommitType string `json:"commit_type"` // "new" or "reword" + ExistingMessage string `json:"existing_message,omitempty"` // For reword operations +} + +// GenerateResponse contains the AI-generated commit message +type GenerateResponse struct { + Message string `json:"message"` + Description string `json:"description,omitempty"` + Confidence float64 `json:"confidence"` + Provider string `json:"provider"` +} + +// FileChange represents a change to a file in the commit +type FileChange struct { + Path string `json:"path"` + Status string `json:"status"` // "added", "modified", "deleted", "renamed" + Language string `json:"language"` // Programming language detected + LinesAdded int `json:"lines_added"` + LinesDeleted int `json:"lines_deleted"` + IsBinary bool `json:"is_binary"` +} + +// GitContext contains git repository context for message generation +type GitContext struct { + BranchName string `json:"branch_name"` + ProjectName string `json:"project_name"` + RecentCommits []string `json:"recent_commits"` + RepositoryType string `json:"repository_type"` // detected from files + ConventionalCommits bool `json:"conventional_commits"` // whether to use conventional format +} + +// Common errors +var ( + ErrUnsupportedProvider = errors.New("unsupported AI provider") + ErrNotConfigured = errors.New("AI service is not configured") + ErrInvalidAPIKey = errors.New("invalid API key") + ErrNetworkError = errors.New("network error while calling AI service") + ErrInvalidResponse = errors.New("invalid response from AI service") + ErrEmptyDiff = errors.New("no staged changes to generate commit message for") + ErrMessageTooLong = errors.New("generated commit message is too long") + ErrInappropriateContent = errors.New("generated message contains inappropriate content") +) diff --git a/pkg/gui/services/ai/providers/github.go b/pkg/gui/services/ai/providers/github.go new file mode 100644 index 00000000000..91edfdb917e --- /dev/null +++ b/pkg/gui/services/ai/providers/github.go @@ -0,0 +1,525 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + "github.com/jesseduffield/lazygit/pkg/config" +) + +const ( + copilotAuthDeviceCodeURL = "https://github.com/login/device/code" + copilotAuthTokenURL = "https://github.com/login/oauth/access_token" + copilotChatAuthURL = "https://api.github.com/copilot_internal/v2/token" + copilotChatCompletionsURL = "https://api.githubcopilot.com/chat/completions" + copilotEditorVersion = "vscode/1.95.3" + copilotUserAgent = "curl/7.81.0" + copilotClientID = "Iv1.b507a08c87ecfe98" +) + +// GitHubProvider implements the Provider interface for GitHub Copilot API +type GitHubProvider struct { + config config.AIConfig + httpClient *http.Client + accessToken *AccessToken +} + +// AccessToken response from GitHub Copilot's token endpoint +type AccessToken struct { + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` + Endpoints struct { + API string `json:"api"` + OriginTracker string `json:"origin-tracker"` + Proxy string `json:"proxy"` + Telemetry string `json:"telemetry"` + } `json:"endpoints"` + ErrorDetails *struct { + URL string `json:"url,omitempty"` + Message string `json:"message,omitempty"` + Title string `json:"title,omitempty"` + NotificationID string `json:"notification_id,omitempty"` + } `json:"error_details,omitempty"` +} + +type DeviceCodeResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +type DeviceTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + Error string `json:"error,omitempty"` +} + +type FailedRequestResponse struct { + DocumentationURL string `json:"documentation_url"` + Message string `json:"message"` +} + +type OAuthTokenWrapper struct { + User string `json:"user"` + OAuthToken string `json:"oauth_token"` + GithubAppID string `json:"githubAppId"` +} + +type OAuthToken struct { + GithubWrapper OAuthTokenWrapper `json:"github.com:Iv1.b507a08c87ecfe98"` +} + +// GitHub Copilot API request/response structures +type copilotRequest struct { + Messages []copilotMessage `json:"messages"` + Model string `json:"model"` + Temperature float64 `json:"temperature"` + TopP float64 `json:"top_p"` + N int `json:"n"` + Stream bool `json:"stream"` + MaxTokens int `json:"max_tokens,omitempty"` +} + +type copilotMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type copilotResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []copilotChoice `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + Error *copilotError `json:"error,omitempty"` +} + +type copilotChoice struct { + Index int `json:"index"` + Message copilotMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type copilotError struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` +} + +// NewGitHubProvider creates a new GitHub Copilot provider +func NewGitHubProvider(config config.AIConfig) *GitHubProvider { + return &GitHubProvider{ + config: config, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// GenerateMessage generates a commit message using GitHub Copilot API +func (p *GitHubProvider) GenerateMessage(prompt string) (string, error) { + // Authenticate and get access token + if err := p.authenticate(); err != nil { + return "", fmt.Errorf("failed to authenticate with GitHub Copilot: %w", err) + } + + // Build request + request := p.buildRequest(prompt) + + // Make API call + ctx := context.Background() + response, err := p.makeAPICall(ctx, request) + if err != nil { + return "", fmt.Errorf("failed to call GitHub Copilot API: %w", err) + } + + // Extract message + message, err := p.extractMessage(response) + if err != nil { + return "", fmt.Errorf("failed to extract message from response: %w", err) + } + + return message, nil +} + +// Name returns the provider name +func (p *GitHubProvider) Name() string { + return "github" +} + +// ValidateConfig validates the GitHub Copilot configuration +func (p *GitHubProvider) ValidateConfig() error { + // GitHub Copilot uses OAuth authentication, so we don't require an API key + // We'll validate during authentication flow instead + return nil +} + +// authenticate handles GitHub Copilot authentication +func (p *GitHubProvider) authenticate() error { + // Check if we have a valid access token + if p.accessToken != nil && p.accessToken.ExpiresAt > time.Now().Unix() { + return nil + } + + // Get OAuth token from config files or login flow + oauthToken, err := p.getOAuthToken() + if err != nil { + return fmt.Errorf("failed to get OAuth token: %w", err) + } + + // Get access token for Copilot API + accessToken, err := p.getCopilotAccessToken(oauthToken) + if err != nil { + return fmt.Errorf("failed to get Copilot access token: %w", err) + } + + p.accessToken = &accessToken + return nil +} + +// getOAuthToken gets the OAuth token from config files or initiates login flow +func (p *GitHubProvider) getOAuthToken() (string, error) { + configPath := filepath.Join(os.Getenv("HOME"), ".config/github-copilot") + if runtime.GOOS == "windows" { + configPath = filepath.Join(os.Getenv("LOCALAPPDATA"), "github-copilot") + } + + // Support both legacy and current config file locations + legacyConfigPath := filepath.Join(configPath, "hosts.json") + currentConfigPath := filepath.Join(configPath, "apps.json") + + // Try to get token from config files + configFiles := []string{legacyConfigPath, currentConfigPath} + for _, path := range configFiles { + token, err := p.extractTokenFromFile(path) + if err == nil && token != "" { + return token, nil + } + } + + // If no token found, initiate login flow + return p.loginFlow(currentConfigPath) +} + +// extractTokenFromFile extracts OAuth token from config file +func (p *GitHubProvider) extractTokenFromFile(path string) (string, error) { + bytes, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read Copilot configuration file at %s: %w", path, err) + } + + var config map[string]json.RawMessage + if err := json.Unmarshal(bytes, &config); err != nil { + return "", fmt.Errorf("failed to parse Copilot configuration file at %s: %w", path, err) + } + + for key, value := range config { + if key == "github.com" || strings.HasPrefix(key, "github.com:") { + var tokenData map[string]string + if err := json.Unmarshal(value, &tokenData); err != nil { + continue + } + if token, exists := tokenData["oauth_token"]; exists { + return token, nil + } + } + } + + return "", fmt.Errorf("no token found in %s", path) +} + +// loginFlow initiates the GitHub OAuth device flow +func (p *GitHubProvider) loginFlow(configPath string) (string, error) { + data := strings.NewReader(fmt.Sprintf("client_id=%s&scope=copilot", copilotClientID)) + req, err := http.NewRequest("POST", copilotAuthDeviceCodeURL, data) + if err != nil { + return "", err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := p.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to get device code: %w", err) + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to decode device code response: %w", err) + } + + parsedData, err := url.ParseQuery(string(responseBody)) + if err != nil { + return "", fmt.Errorf("failed to parse device code response: %w", err) + } + + deviceCodeResp := DeviceCodeResponse{ + UserCode: parsedData.Get("user_code"), + DeviceCode: parsedData.Get("device_code"), + VerificationURI: parsedData.Get("verification_uri"), + } + + deviceCodeResp.ExpiresIn, _ = strconv.Atoi(parsedData.Get("expires_in")) + deviceCodeResp.Interval, _ = strconv.Atoi(parsedData.Get("interval")) + + fmt.Printf("Please go to %s and enter the code %s\n", deviceCodeResp.VerificationURI, deviceCodeResp.UserCode) + + oAuthToken, err := p.fetchRefreshToken(deviceCodeResp.DeviceCode, deviceCodeResp.Interval, deviceCodeResp.ExpiresIn) + if err != nil { + return "", err + } + + err = p.saveOAuthToken(OAuthToken{ + GithubWrapper: OAuthTokenWrapper{ + User: "", + OAuthToken: oAuthToken.AccessToken, + GithubAppID: copilotClientID, + }, + }, configPath) + if err != nil { + return "", err + } + + return oAuthToken.AccessToken, nil +} + +// fetchRefreshToken polls for the OAuth token +func (p *GitHubProvider) fetchRefreshToken(deviceCode string, interval int, expiresIn int) (DeviceTokenResponse, error) { + var accessTokenResp DeviceTokenResponse + var errResp FailedRequestResponse + + time.Sleep(30 * time.Second) // Give user time to open browser + + endTime := time.Now().Add(time.Duration(expiresIn) * time.Second) + ticker := time.NewTicker(time.Duration(interval) * time.Second) + defer ticker.Stop() + + for range ticker.C { + if time.Now().After(endTime) { + return DeviceTokenResponse{}, fmt.Errorf("authorization polling timeout") + } + + fmt.Println("Trying to fetch token...") + data := strings.NewReader(fmt.Sprintf( + "client_id=%s&device_code=%s&grant_type=urn:ietf:params:oauth:grant-type:device_code", + copilotClientID, deviceCode, + )) + + req, err := http.NewRequest("POST", copilotAuthTokenURL, data) + if err != nil { + return DeviceTokenResponse{}, err + } + req.Header.Set("Accept", "application/json") + + resp, err := p.httpClient.Do(req) + if err != nil { + return DeviceTokenResponse{}, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil { + return DeviceTokenResponse{}, err + } + return DeviceTokenResponse{}, fmt.Errorf("failed to check refresh token: %s", errResp.Message) + } + + if err := json.NewDecoder(resp.Body).Decode(&accessTokenResp); err != nil { + return DeviceTokenResponse{}, err + } + + if accessTokenResp.AccessToken != "" { + return accessTokenResp, nil + } + + if accessTokenResp.Error != "" && accessTokenResp.Error != "authorization_pending" { + return DeviceTokenResponse{}, fmt.Errorf("token error: %s", accessTokenResp.Error) + } + } + + return DeviceTokenResponse{}, fmt.Errorf("authorization polling failed or timed out") +} + +// saveOAuthToken saves the OAuth token to the config file +func (p *GitHubProvider) saveOAuthToken(oAuthToken OAuthToken, configPath string) error { + fileContent, err := json.Marshal(oAuthToken) + if err != nil { + return fmt.Errorf("error marshaling oAuthToken: %w", err) + } + + configDir := filepath.Dir(configPath) + if err = os.MkdirAll(configDir, 0o700); err != nil { + return fmt.Errorf("error creating config directory: %w", err) + } + + err = os.WriteFile(configPath, fileContent, 0o700) + if err != nil { + return fmt.Errorf("error writing oAuthToken to %s: %w", configPath, err) + } + + return nil +} + +// getCopilotAccessToken exchanges OAuth token for Copilot access token +func (p *GitHubProvider) getCopilotAccessToken(oauthToken string) (AccessToken, error) { + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, copilotChatAuthURL, nil) + if err != nil { + return AccessToken{}, fmt.Errorf("failed to create token request: %w", err) + } + + req.Header.Set("Authorization", "token "+oauthToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("Editor-Version", copilotEditorVersion) + req.Header.Set("User-Agent", copilotUserAgent) + + resp, err := p.httpClient.Do(req) + if err != nil { + return AccessToken{}, fmt.Errorf("failed to get access token: %w", err) + } + defer resp.Body.Close() + + var tokenResponse AccessToken + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + return AccessToken{}, fmt.Errorf("failed to decode token response: %w", err) + } + + if tokenResponse.ErrorDetails != nil { + return AccessToken{}, fmt.Errorf("token error: %s", tokenResponse.ErrorDetails.Message) + } + + return tokenResponse, nil +} + +// buildRequest creates the GitHub Copilot API request payload +func (p *GitHubProvider) buildRequest(prompt string) *copilotRequest { + model := p.config.Model + if model == "" { + model = "gpt-4" + } + + temperature := p.config.Temperature + if temperature == 0 { + temperature = 0.7 + } + + maxTokens := p.config.MaxTokens + if maxTokens == 0 { + maxTokens = 500 + } + + return &copilotRequest{ + Messages: []copilotMessage{ + { + Role: "system", + Content: "You are a helpful assistant that generates concise, descriptive git commit messages. Focus on what was changed and why, following conventional commit format when appropriate.", + }, + { + Role: "user", + Content: fmt.Sprintf("Generate a git commit message for the following changes:\n\n%s", prompt), + }, + }, + Model: model, + Temperature: temperature, + TopP: 1.0, + N: 1, + Stream: false, + MaxTokens: maxTokens, + } +} + +// makeAPICall makes the HTTP request to GitHub Copilot API +func (p *GitHubProvider) makeAPICall(ctx context.Context, request *copilotRequest) (*copilotResponse, error) { + // Serialize request + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Create HTTP request + url := copilotChatCompletionsURL + if p.config.BaseURL != "" { + url = p.config.BaseURL + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.accessToken.Token) + req.Header.Set("Accept", "application/json") + req.Header.Set("Editor-Version", copilotEditorVersion) + req.Header.Set("User-Agent", copilotUserAgent) + + // Make request + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + // Read response + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(responseBody)) + } + + // Parse response + var apiResponse copilotResponse + if err := json.Unmarshal(responseBody, &apiResponse); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Check for API errors + if apiResponse.Error != nil { + return nil, fmt.Errorf("GitHub Copilot API error: %s", apiResponse.Error.Message) + } + + return &apiResponse, nil +} + +// extractMessage extracts the generated message from GitHub Copilot response +func (p *GitHubProvider) extractMessage(response *copilotResponse) (string, error) { + if len(response.Choices) == 0 { + return "", fmt.Errorf("no choices in response") + } + + choice := response.Choices[0] + message := choice.Message.Content + + if message == "" { + return "", fmt.Errorf("empty message in response") + } + + // Clean up the message - remove quotes and extra whitespace + message = strings.Trim(message, "\"'\n\r\t ") + + return message, nil +} diff --git a/pkg/gui/services/ai/providers/github_test.go b/pkg/gui/services/ai/providers/github_test.go new file mode 100644 index 00000000000..c98d781a972 --- /dev/null +++ b/pkg/gui/services/ai/providers/github_test.go @@ -0,0 +1,194 @@ +package providers + +import ( + "testing" + + "github.com/jesseduffield/lazygit/pkg/config" +) + +func TestGitHubProvider_Name(t *testing.T) { + provider := NewGitHubProvider(config.AIConfig{}) + if provider.Name() != "github" { + t.Errorf("Expected provider name to be 'github', got '%s'", provider.Name()) + } +} + +func TestGitHubProvider_ValidateConfig(t *testing.T) { + tests := []struct { + name string + config config.AIConfig + wantErr bool + }{ + { + name: "valid config without API key (OAuth used)", + config: config.AIConfig{}, + wantErr: false, + }, + { + name: "valid config with all fields", + config: config.AIConfig{ + Provider: "github", + Model: "gpt-4", + Temperature: 0.7, + MaxTokens: 500, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := NewGitHubProvider(tt.config) + err := provider.ValidateConfig() + if (err != nil) != tt.wantErr { + t.Errorf("GitHubProvider.ValidateConfig() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestGitHubProvider_buildRequest(t *testing.T) { + config := config.AIConfig{ + Model: "gpt-4", + Temperature: 0.8, + MaxTokens: 1000, + } + provider := NewGitHubProvider(config) + + prompt := "Add user authentication functionality" + request := provider.buildRequest(prompt) + + if request.Model != "gpt-4" { + t.Errorf("Expected model to be 'gpt-4', got '%s'", request.Model) + } + + if request.Temperature != 0.8 { + t.Errorf("Expected temperature to be 0.8, got %f", request.Temperature) + } + + if request.MaxTokens != 1000 { + t.Errorf("Expected max tokens to be 1000, got %d", request.MaxTokens) + } + + if len(request.Messages) != 2 { + t.Errorf("Expected 2 messages, got %d", len(request.Messages)) + } + + if request.Messages[0].Role != "system" { + t.Errorf("Expected first message role to be 'system', got '%s'", request.Messages[0].Role) + } + + if request.Messages[1].Role != "user" { + t.Errorf("Expected second message role to be 'user', got '%s'", request.Messages[1].Role) + } + + if request.Messages[1].Content != "Generate a git commit message for the following changes:\n\nAdd user authentication functionality" { + t.Errorf("Unexpected user message content: %s", request.Messages[1].Content) + } +} + +func TestGitHubProvider_buildRequest_defaults(t *testing.T) { + config := config.AIConfig{} // Empty config to test defaults + provider := NewGitHubProvider(config) + + request := provider.buildRequest("test prompt") + + if request.Model != "gpt-4" { + t.Errorf("Expected default model to be 'gpt-4', got '%s'", request.Model) + } + + if request.Temperature != 0.7 { + t.Errorf("Expected default temperature to be 0.7, got %f", request.Temperature) + } + + if request.MaxTokens != 500 { + t.Errorf("Expected default max tokens to be 500, got %d", request.MaxTokens) + } + + if request.TopP != 1.0 { + t.Errorf("Expected TopP to be 1.0, got %f", request.TopP) + } + + if request.N != 1 { + t.Errorf("Expected N to be 1, got %d", request.N) + } + + if request.Stream != false { + t.Errorf("Expected Stream to be false, got %t", request.Stream) + } +} + +func TestGitHubProvider_extractMessage(t *testing.T) { + tests := []struct { + name string + response *copilotResponse + want string + wantErr bool + }{ + { + name: "valid response", + response: &copilotResponse{ + Choices: []copilotChoice{ + { + Message: copilotMessage{ + Content: "feat: add user authentication", + }, + }, + }, + }, + want: "feat: add user authentication", + wantErr: false, + }, + { + name: "response with quotes and whitespace", + response: &copilotResponse{ + Choices: []copilotChoice{ + { + Message: copilotMessage{ + Content: " \"fix: resolve login bug\" \n", + }, + }, + }, + }, + want: "fix: resolve login bug", + wantErr: false, + }, + { + name: "empty choices", + response: &copilotResponse{ + Choices: []copilotChoice{}, + }, + want: "", + wantErr: true, + }, + { + name: "empty message content", + response: &copilotResponse{ + Choices: []copilotChoice{ + { + Message: copilotMessage{ + Content: "", + }, + }, + }, + }, + want: "", + wantErr: true, + }, + } + + provider := NewGitHubProvider(config.AIConfig{}) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := provider.extractMessage(tt.response) + if (err != nil) != tt.wantErr { + t.Errorf("GitHubProvider.extractMessage() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GitHubProvider.extractMessage() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/gui/services/ai/providers/openai.go b/pkg/gui/services/ai/providers/openai.go new file mode 100644 index 00000000000..e576d16c782 --- /dev/null +++ b/pkg/gui/services/ai/providers/openai.go @@ -0,0 +1,196 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/jesseduffield/lazygit/pkg/config" +) + +// OpenAIProvider implements the Provider interface for OpenAI API +type OpenAIProvider struct { + config config.AIConfig + httpClient *http.Client +} + +// NewOpenAIProvider creates a new OpenAI provider +func NewOpenAIProvider(config config.AIConfig) *OpenAIProvider { + return &OpenAIProvider{ + config: config, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// GenerateMessage generates a commit message using OpenAI API +func (p *OpenAIProvider) GenerateMessage(prompt string) (string, error) { + // TODO: Implement OpenAI API call + // 1. Build request payload + // 2. Make HTTP request to OpenAI API + // 3. Parse response + // 4. Return generated message + + return "", fmt.Errorf("OpenAI provider not implemented yet") +} + +// Name returns the provider name +func (p *OpenAIProvider) Name() string { + return "openai" +} + +// ValidateConfig validates the OpenAI configuration +func (p *OpenAIProvider) ValidateConfig() error { + // TODO: Validate OpenAI configuration + // 1. Check API key is present + // 2. Validate model name + // 3. Check base URL format + // 4. Optionally test API connectivity + + if p.config.APIKey == "" { + return fmt.Errorf("OpenAI API key is required") + } + + if p.config.Model == "" { + return fmt.Errorf("OpenAI model is required") + } + + return nil +} + +// OpenAI API request/response structures +type openAIRequest struct { + Model string `json:"model"` + Messages []openAIMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Stream bool `json:"stream"` +} + +type openAIMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type openAIResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []openAIChoice `json:"choices"` + Usage openAIUsage `json:"usage"` + Error *openAIError `json:"error,omitempty"` +} + +type openAIChoice struct { + Index int `json:"index"` + Message openAIMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type openAIUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type openAIError struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` +} + +// buildRequest creates the OpenAI API request payload +func (p *OpenAIProvider) buildRequest(prompt string) *openAIRequest { + // TODO: Build proper request with system and user messages + return &openAIRequest{ + Model: p.config.Model, + Messages: []openAIMessage{ + { + Role: "system", + Content: "You are a helpful assistant that generates concise, clear git commit messages based on code changes.", + }, + { + Role: "user", + Content: prompt, + }, + }, + MaxTokens: p.config.MaxTokens, + Temperature: p.config.Temperature, + Stream: false, + } +} + +// makeAPICall makes the HTTP request to OpenAI API +func (p *OpenAIProvider) makeAPICall(ctx context.Context, request *openAIRequest) (*openAIResponse, error) { + // TODO: Implement actual API call + // 1. Serialize request to JSON + // 2. Create HTTP request with proper headers + // 3. Make request with context + // 4. Parse response + // 5. Handle errors + + // Serialize request + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Create HTTP request + url := p.config.BaseURL + "/chat/completions" + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.config.APIKey) + + // Make request (stub for now) + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + // Read response + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + // Parse response + var apiResponse openAIResponse + if err := json.Unmarshal(responseBody, &apiResponse); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Check for API errors + if apiResponse.Error != nil { + return nil, fmt.Errorf("OpenAI API error: %s", apiResponse.Error.Message) + } + + return &apiResponse, nil +} + +// extractMessage extracts the generated message from OpenAI response +func (p *OpenAIProvider) extractMessage(response *openAIResponse) (string, error) { + if len(response.Choices) == 0 { + return "", fmt.Errorf("no choices in response") + } + + choice := response.Choices[0] + message := choice.Message.Content + + if message == "" { + return "", fmt.Errorf("empty message in response") + } + + return message, nil +} diff --git a/pkg/gui/services/ai/service.go b/pkg/gui/services/ai/service.go new file mode 100644 index 00000000000..2ebe254620b --- /dev/null +++ b/pkg/gui/services/ai/service.go @@ -0,0 +1,79 @@ +package ai + +import ( + "context" + + "github.com/jesseduffield/lazygit/pkg/gui/controllers/helpers" + "github.com/jesseduffield/lazygit/pkg/gui/services/ai/providers" +) + +// Client is the entry point for AI-powered commit message generation +// It follows the same pattern as the custom_commands service +type Client struct { + c *helpers.HelperCommon + contextBuilder *ContextBuilder + messageValidator *MessageValidator +} + +// NewClient creates a new AI service client +func NewClient(c *helpers.HelperCommon) *Client { + contextBuilder := NewContextBuilder(c) + messageValidator := NewMessageValidator() + + return &Client{ + c: c, + contextBuilder: contextBuilder, + messageValidator: messageValidator, + } +} + +// GenerateCommitMessage generates a commit message based on the current git state +func (c *Client) GenerateCommitMessage(ctx context.Context) (*GenerateResponse, error) { + // TODO: Implement commit message generation + // 1. Check if AI is configured and enabled + // 2. Build context from git state + // 3. Call provider to generate message + // 4. Validate and return response + return nil, nil +} + +// GenerateCommitMessageForReword generates an improved commit message for rewording an existing commit +func (c *Client) GenerateCommitMessageForReword(ctx context.Context, existingMessage string, commitSha string) (*GenerateResponse, error) { + // TODO: Implement reword message generation + // 1. Get commit diff and context + // 2. Build prompt with existing message for improvement + // 3. Call provider to generate improved message + // 4. Validate and return response + return nil, nil +} + +// IsConfigured returns whether AI is properly configured +func (c *Client) IsConfigured() bool { + // TODO: Check if AI config is valid and provider is available + config := c.c.UserConfig().AI + return config.Enabled && config.APIKey != "" && config.Provider != "" +} + +// ValidateConfig validates the current AI configuration +func (c *Client) ValidateConfig() error { + // TODO: Validate AI configuration + // 1. Check required fields + // 2. Validate provider settings + // 3. Test connectivity if possible + return nil +} + +// getProvider returns the configured AI provider +func (c *Client) getProvider() (Provider, error) { + // TODO: Initialize and return the appropriate provider based on config + config := c.c.UserConfig().AI + + switch config.Provider { + case "openai": + return providers.NewOpenAIProvider(config), nil + case "github": + return providers.NewGitHubProvider(config), nil + default: + return nil, ErrUnsupportedProvider + } +}