diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 20c478eb4..8b0367c0d 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -18,7 +18,12 @@ from codeflash.code_utils.time_utils import humanize_runtime from codeflash.lsp.helpers import is_LSP_enabled from codeflash.models.ExperimentMetadata import ExperimentMetadata -from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate +from codeflash.models.models import ( + AIServiceRefinerRequest, + CodeStringsMarkdown, + OptimizedCandidate, + OptimizedCandidateSource, +) from codeflash.telemetry.posthog_cf import ph from codeflash.version import __version__ as codeflash_version @@ -27,7 +32,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.ExperimentMetadata import ExperimentMetadata - from codeflash.models.models import AIServiceRefinerRequest + from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest from codeflash.result.explanation import Explanation @@ -86,7 +91,9 @@ def make_ai_service_request( # response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code return response - def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]: + def _get_valid_candidates( + self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource + ) -> list[OptimizedCandidate]: candidates: list[OptimizedCandidate] = [] for opt in optimizations_json: code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"]) @@ -94,7 +101,10 @@ def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> lis continue candidates.append( OptimizedCandidate( - source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"] + source_code=code, + explanation=opt["explanation"], + optimization_id=opt["optimization_id"], + source=source, ) ) return candidates @@ -157,7 +167,7 @@ def optimize_python_code( # noqa: D417 console.rule() end_time = time.perf_counter() logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.") - return self._get_valid_candidates(optimizations_json) + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE) try: error = response.json()["error"] except Exception: @@ -222,7 +232,7 @@ def optimize_python_code_line_profiler( # noqa: D417 f"!lsp|Generated {len(optimizations_json)} candidate optimizations using line profiler information." ) console.rule() - return self._get_valid_candidates(optimizations_json) + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP) try: error = response.json()["error"] except Exception: @@ -275,15 +285,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") console.rule() - refinements = self._get_valid_candidates(refined_optimizations) - return [ - OptimizedCandidate( - source_code=c.source_code, - explanation=c.explanation, - optimization_id=c.optimization_id[:-4] + "refi", - ) - for c in refinements - ] + return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE) try: error = response.json()["error"] @@ -294,6 +296,54 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [] + def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None: + """Optimize the given python code for performance by making a request to the Django endpoint. + + Args: + request: optimization candidate details for refinement + + Returns: + ------- + - OptimizationCandidate: new fixed candidate. + + """ + console.rule() + try: + payload = { + "optimization_id": request.optimization_id, + "original_source_code": request.original_source_code, + "modified_source_code": request.modified_source_code, + "trace_id": request.trace_id, + "test_diffs": request.test_diffs, + "past_trials": request.past_trials, + "trial_no": request.trial_no + } + response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120) + except (requests.exceptions.RequestException, TypeError) as e: + logger.exception(f"Error generating optimization repair: {e}") + ph("cli-optimize-error-caught", {"error": str(e)}) + return None + + if response.status_code == 200: + fixed_optimization = response.json() + console.rule() + + valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.REPAIR) + if not valid_candidates: + logger.error("Code repair failed to generate a valid candidate.") + return None + + return valid_candidates[0] + + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return None + def get_new_explanation( # noqa: D417 self, source_code: str, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 744f76087..d972552d3 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -3,6 +3,7 @@ from collections import Counter, defaultdict from typing import TYPE_CHECKING +import libcst as cst from rich.tree import Tree from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log @@ -47,6 +48,36 @@ class AIServiceRefinerRequest: function_references: str | None = None +class TestDiffScope(str, Enum): + RETURN_VALUE = "return_value" + STDOUT = "stdout" + DID_PASS = "did_pass" # noqa: S105 + + +@dataclass +class TestDiff: + scope: TestDiffScope + original_pass: bool + candidate_pass: bool + + original_value: str | None = None + candidate_value: str | None = None + test_src_code: Optional[str] = None + candidate_pytest_error: Optional[str] = None + original_pytest_error: Optional[str] = None + + +@dataclass(frozen=True) +class AIServiceCodeRepairRequest: + optimization_id: str + original_source_code: str + modified_source_code: str + trace_id: str + test_diffs: list[TestDiff] + past_trials: str + trial_no: str + + # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully # qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name # of the module is foo.eggs. @@ -354,11 +385,19 @@ class TestsInFile: test_type: TestType +class OptimizedCandidateSource(str, Enum): + OPTIMIZE = "OPTIMIZE" + OPTIMIZE_LP = "OPTIMIZE_LP" + REFINE = "REFINE" + REPAIR = "REPAIR" + + @dataclass(frozen=True) class OptimizedCandidate: source_code: CodeStringsMarkdown explanation: str optimization_id: str + source: OptimizedCandidateSource @dataclass(frozen=True) @@ -505,6 +544,42 @@ def id(self) -> str: f"{self.function_getting_tested}:{self.iteration_id}" ) + # TestSuiteClass.test_function_name + def test_fn_qualified_name(self) -> str: + # Use f-string with inline conditional to reduce string concatenation operations + return ( + f"{self.test_class_name}.{self.test_function_name}" + if self.test_class_name + else str(self.test_function_name) + ) + + def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]: + for stmt in class_node.body.body: + if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name: + return stmt + return None + + def get_src_code(self, test_path: Path) -> Optional[str]: + if not test_path.exists(): + return None + test_src = test_path.read_text(encoding="utf-8") + module_node = cst.parse_module(test_src) + + if self.test_class_name: + for stmt in module_node.body: + if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name: + func_node = self.find_func_in_class(stmt, self.test_function_name) + if func_node: + return module_node.code_for_node(func_node).strip() + # class not found + return None + + # Otherwise, look for a top level function + for stmt in module_node.body: + if isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name: + return module_node.code_for_node(stmt).strip() + return None + @staticmethod def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId: components = string_id.split(":") @@ -549,7 +624,10 @@ class TestResults(BaseModel): # noqa: PLW1641 # also we don't support deletion of test results elements - caution is advised test_results: list[FunctionTestInvocation] = [] test_result_idx: dict[str, int] = {} + perf_stdout: Optional[str] = None + # mapping between test function name and stdout failure message + test_failures: Optional[dict[str, str]] = None def add(self, function_test_invocation: FunctionTestInvocation) -> None: unique_id = function_test_invocation.unique_invocation_loop_id diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 2eef51f0f..ad35a6c61 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -5,6 +5,7 @@ import os import queue import random +import sqlite3 import subprocess import time import uuid @@ -48,6 +49,8 @@ N_TESTS_TO_GENERATE_EFFECTIVE, REPEAT_OPTIMIZATION_PROBABILITY, TOTAL_LOOPING_TIME_EFFECTIVE, + MIN_IMPROVEMENT_THRESHOLD, + MIN_TESTCASE_PASSED_THRESHOLD, ) from codeflash.code_utils.deduplicate_code import normalize_code from codeflash.code_utils.edit_generated_tests import ( @@ -67,8 +70,10 @@ from codeflash.either import Failure, Success, is_successful from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId +from codeflash.models import models from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( + AIServiceCodeRepairRequest, BestOptimization, CodeOptimizationContext, GeneratedTests, @@ -76,12 +81,13 @@ OptimizationSet, OptimizedCandidate, OptimizedCandidateResult, + OptimizedCandidateSource, OriginalCodeBaseline, TestFile, TestFiles, TestingMode, TestResults, - TestType, + TestType, TestDiffScope, ) from codeflash.result.create_pr import check_create_pr, existing_tests_source_for from codeflash.result.critic import ( @@ -113,9 +119,149 @@ CoverageData, FunctionCalledInTest, FunctionSource, + TestDiff, ) from codeflash.verification.verification_utils import TestConfig +CODE_REPAIR_LOG_DB = Path("/Users/aseemsaxena/Downloads/codeflash_dev/codeflash-internal/django/aiservice/code_repair/code_repair_log.db") + + +SCOPE_DESCRIPTIONS = { + TestDiffScope.RETURN_VALUE: ( + "The function returned a different value in the optimized code compared to the original." + ), + TestDiffScope.STDOUT: ("The output printed to stdout is different in the optimized code compared to the original."), + TestDiffScope.DID_PASS: ( + "The test passed in one version but failed in the other (a change in pass/fail behavior)." + ), +} + +def build_test_details(test_diffs: list[TestDiff]) -> str: + sections = [] + for test_no, diff in enumerate(test_diffs, 1): + test_src_code = "```python\n" + diff.test_src_code + "\n```" if diff.test_src_code else "" + section = [ + f"#### Test #{test_no}", + f"{SCOPE_DESCRIPTIONS.get(diff.scope, diff.scope.value)}", + f"Expected: {diff.original_value!r}. Got: {diff.candidate_value!r}" + if diff.scope != TestDiffScope.DID_PASS + else "", + f"Original code test status: {'Passed' if diff.original_pass else 'Failed'}. Optimized code test status: {'Passed' if diff.candidate_pass else 'Failed'}", + f"Pytest error (original code): {diff.original_pytest_error}" if diff.original_pytest_error else "", + f"Pytest error (optimized code): {diff.candidate_pytest_error}" if diff.candidate_pytest_error else "", + "Test Source:", + test_src_code, + "---", + ] + sections.append("\n".join(filter(None, section))) + + return "\n".join(sections) + +def _init_code_repair_log_db() -> None: + """Initialize the SQLite database for code repair logging.""" + conn = sqlite3.connect(CODE_REPAIR_LOG_DB) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS code_repair_logs ( + optimization_id TEXT PRIMARY KEY, + trace_id TEXT, + user_prompt TEXT, + explanation TEXT, + refined_optimization TEXT, + trial_no TEXT, + past_trials TEXT, + passed TEXT, + faster TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + conn.commit() + conn.close() + + +def log_code_repair_to_db( + optimization_id: str, + trace_id: str | None = None, + user_prompt: str | None = None, + explanation: str | None = None, + refined_optimization: str | None = None, + trial_no: str | None = None, + past_trials: str | None = None, + passed: str | None = None, + faster: str | None = None, +) -> None: + """Log code repair data to SQLite database. + + Uses upsert pattern to allow incremental logging with different columns at different places. + Only non-None values will be updated; existing values are preserved. + """ + try: + _init_code_repair_log_db() + conn = sqlite3.connect(CODE_REPAIR_LOG_DB) + cursor = conn.cursor() + + # Build dynamic upsert query based on provided columns + columns = ["optimization_id"] + values = [optimization_id] + update_parts = ["updated_at = CURRENT_TIMESTAMP"] + + if trace_id is not None: + columns.append("trace_id") + values.append(trace_id) + update_parts.append("trace_id = excluded.trace_id") + + if user_prompt is not None: + columns.append("user_prompt") + values.append(user_prompt) + update_parts.append("user_prompt = excluded.user_prompt") + + if explanation is not None: + columns.append("explanation") + values.append(explanation) + update_parts.append("explanation = excluded.explanation") + + if refined_optimization is not None: + columns.append("refined_optimization") + values.append(refined_optimization) + update_parts.append("refined_optimization = excluded.refined_optimization") + + if trial_no is not None: + columns.append("trial_no") + values.append(trial_no) + update_parts.append("trial_no = excluded.trial_no") + + if past_trials is not None: + columns.append("past_trials") + values.append(past_trials) + update_parts.append("past_trials = excluded.past_trials") + + if passed is not None: + columns.append("passed") + values.append(passed) + update_parts.append("passed = excluded.passed") + + if faster is not None: + columns.append("faster") + values.append(faster) + update_parts.append("faster = excluded.faster") + + placeholders = ", ".join(["?"] * len(values)) + columns_str = ", ".join(columns) + update_str = ", ".join(update_parts) + + cursor.execute( + f""" + INSERT INTO code_repair_logs ({columns_str}) + VALUES ({placeholders}) + ON CONFLICT(optimization_id) DO UPDATE SET {update_str} + """, # noqa: S608 + values, + ) + conn.commit() + conn.close() + except Exception: + logger.exception("Failed to log code repair data to SQLite") class CandidateProcessor: """Handles candidate processing using a queue-based approach.""" @@ -124,7 +270,8 @@ def __init__( self, initial_candidates: list, future_line_profile_results: concurrent.futures.Future, - future_all_refinements: list, + future_all_refinements: list[concurrent.futures.Future], + future_all_code_repair: list[concurrent.futures.Future], ) -> None: self.candidate_queue = queue.Queue() self.line_profiler_done = False @@ -137,6 +284,7 @@ def __init__( self.future_line_profile_results = future_line_profile_results self.future_all_refinements = future_all_refinements + self.future_all_code_repair = future_all_code_repair def get_next_candidate(self) -> OptimizedCandidate | None: """Get the next candidate from the queue, handling async results as needed.""" @@ -149,6 +297,8 @@ def _handle_empty_queue(self) -> OptimizedCandidate | None: """Handle empty queue by checking for pending async results.""" if not self.line_profiler_done: return self._process_line_profiler_results() + if len(self.future_all_code_repair) > 0: + return self._process_code_repair() if self.line_profiler_done and not self.refinement_done: return self._process_refinement_results() return None # All done @@ -188,10 +338,30 @@ def _process_refinement_results(self) -> OptimizedCandidate | None: logger.info( f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}" ) + self.future_all_refinements = [] self.refinement_done = True return self.get_next_candidate() + def _process_code_repair(self) -> OptimizedCandidate | None: + logger.info(f"loading|Repairing {len(self.future_all_code_repair)} candidates") + concurrent.futures.wait(self.future_all_code_repair) + candidates_added = 0 + for future_code_repair in self.future_all_code_repair: + possible_code_repair = future_code_repair.result() + if possible_code_repair: + self.candidate_queue.put(possible_code_repair) + self.candidate_len += 1 + candidates_added += 1 + + if candidates_added > 0: + logger.info( + f"Added {candidates_added} candidates from code repair, total candidates now: {self.candidate_len}" + ) + self.future_all_code_repair = [] + + return self.get_next_candidate() + def is_done(self) -> bool: """Check if processing is complete.""" return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty() @@ -247,6 +417,9 @@ def __init__( max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4 ) self.optimization_review = "" + self.ast_code_to_id = {} + self.future_all_refinements: list[concurrent.futures.Future] = [] + self.future_all_code_repair: list[concurrent.futures.Future] = [] def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -387,7 +560,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: initialization_result = self.can_be_optimized() if not is_successful(initialization_result): return Failure(initialization_result.failure()) - should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() code_print( @@ -459,6 +631,48 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) + def reset_optimization_metrics_for_candidate( + self, opt_id: str, speedup_ratios: dict, is_correct: dict, optimized_runtimes: dict + ) -> None: + speedup_ratios[opt_id] = None + is_correct[opt_id] = False + optimized_runtimes[opt_id] = None + + def was_candidate_tested_before(self, normalized_code: str) -> bool: + # check if this code has been evaluated before by checking the ast normalized code string + return normalized_code in self.ast_code_to_id + + def update_results_for_duplicate_candidate( + self, + candidate: OptimizedCandidate, + code_context: CodeOptimizationContext, + normalized_code: str, + speedup_ratios: dict, + is_correct: dict, + optimized_runtimes: dict, + optimized_line_profiler_results: dict, + optimizations_post: dict, + ) -> None: + logger.info("Current candidate has been encountered before in testing, Skipping optimization candidate.") + past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] + # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes + speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] + is_correct[candidate.optimization_id] = is_correct[past_opt_id] + optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] + # line profiler results only available for successful runs + if past_opt_id in optimized_line_profiler_results: + optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[past_opt_id] + optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][ + "shorter_source_code" + ].markdown + optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown + new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) + if ( + new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"] + ): # new candidate has a shorter diff than the previously encountered one + self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code + self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + def determine_best_candidate( self, *, @@ -484,8 +698,10 @@ def determine_best_candidate( ) console.rule() - future_all_refinements: list[concurrent.futures.Future] = [] - ast_code_to_id = {} + self.ast_code_to_id.clear() + self.future_all_refinements.clear() + self.future_all_code_repair.clear() + valid_optimizations = [] optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated @@ -506,7 +722,9 @@ def determine_best_candidate( ) # Initialize candidate processor - processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements) + processor = CandidateProcessor( + candidates, future_line_profile_results, self.future_all_refinements, self.future_all_code_repair + ) candidate_index = 0 # Process candidates using queue-based approach @@ -548,47 +766,40 @@ def determine_best_candidate( continue # check if this code has been evaluated before by checking the ast normalized code string normalized_code = normalize_code(candidate.source_code.flat.strip()) - if normalized_code in ast_code_to_id: - logger.info( - "Current candidate has been encountered before in testing, Skipping optimization candidate." + if self.was_candidate_tested_before(normalized_code): + self.update_results_for_duplicate_candidate( + candidate=candidate, + code_context=code_context, + normalized_code=normalized_code, + speedup_ratios=speedup_ratios, + is_correct=is_correct, + optimized_runtimes=optimized_runtimes, + optimized_line_profiler_results=optimized_line_profiler_results, + optimizations_post=optimizations_post, ) - past_opt_id = ast_code_to_id[normalized_code]["optimization_id"] - # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes - speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] - is_correct[candidate.optimization_id] = is_correct[past_opt_id] - optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] - # line profiler results only available for successful runs - if past_opt_id in optimized_line_profiler_results: - optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[ - past_opt_id - ] - optimizations_post[candidate.optimization_id] = ast_code_to_id[normalized_code][ - "shorter_source_code" - ].markdown - optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown - new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) - if ( - new_diff_len < ast_code_to_id[normalized_code]["diff_len"] - ): # new candidate has a shorter diff than the previously encountered one - ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code - ast_code_to_id[normalized_code]["diff_len"] = new_diff_len continue - ast_code_to_id[normalized_code] = { + self.ast_code_to_id[normalized_code] = { "optimization_id": candidate.optimization_id, "shorter_source_code": candidate.source_code, "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), } + run_results = self.run_optimized_candidate( optimization_candidate_index=candidate_index, baseline_results=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, + code_context=code_context, + candidate=candidate, + exp_type=exp_type, + original_code_baseline=original_code_baseline ) + console.rule() if not is_successful(run_results): - optimized_runtimes[candidate.optimization_id] = None - is_correct[candidate.optimization_id] = False - speedup_ratios[candidate.optimization_id] = None + self.reset_optimization_metrics_for_candidate( + candidate.optimization_id, speedup_ratios, is_correct, optimized_runtimes + ) else: candidate_result: OptimizedCandidateResult = run_results.unwrap() best_test_runtime = candidate_result.best_test_runtime @@ -672,21 +883,21 @@ def determine_best_candidate( async_throughput=candidate_result.async_throughput, ) valid_optimizations.append(best_optimization) - # queue corresponding refined optimization for best optimization - if not candidate.optimization_id.endswith("refi"): - future_all_refinements.append( - self.refine_optimizations( - valid_optimizations=[best_optimization], - original_code_baseline=original_code_baseline, - code_context=code_context, - trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - ai_service_client=ai_service_client, - executor=self.executor, - function_references=function_references, - ) - ) + # # queue corresponding refined optimization for best optimization + # if candidate.source != OptimizedCandidateSource.REFINE: + # self.future_all_refinements.append( + # self.refine_optimizations( + # valid_optimizations=[best_optimization], + # original_code_baseline=original_code_baseline, + # code_context=code_context, + # trace_id=self.function_trace_id[:-4] + exp_type + # if self.experiment_id + # else self.function_trace_id, + # ai_service_client=ai_service_client, + # executor=self.executor, + # function_references=function_references, + # ) + # ) else: # For async functions, prioritize throughput metrics over runtime even for slow candidates is_async = ( @@ -742,9 +953,10 @@ def determine_best_candidate( for valid_opt in valid_optimizations: valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip()) new_candidate_with_shorter_code = OptimizedCandidate( - source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], + source_code=self.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], optimization_id=valid_opt.candidate.optimization_id, explanation=valid_opt.candidate.explanation, + source=valid_opt.candidate.source, ) new_best_opt = BestOptimization( candidate=new_candidate_with_shorter_code, @@ -839,6 +1051,28 @@ def refine_optimizations( ] return executor.submit(ai_service_client.optimize_python_code_refinement, request=request) + def repair_optimization( + self, + original_source_code: str, + modified_source_code: str, + test_diffs: list[TestDiff], + trace_id: str, + optimization_id: str, + past_trials: str, + trial_no: str, + ai_service_client: AiServiceClient, + ) -> OptimizedCandidate | None: + request = AIServiceCodeRepairRequest( + optimization_id=optimization_id, + original_source_code=original_source_code, + modified_source_code=modified_source_code, + test_diffs=test_diffs, + trace_id=trace_id, + past_trials=past_trials, + trial_no=trial_no + ) + return ai_service_client.optimize_python_code_repair(request=request) + def log_successful_optimization( self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str ) -> None: @@ -1752,6 +1986,11 @@ def establish_original_code_baseline( ) ) + def get_results_not_matched_error(self) -> Failure: + logger.info("h4|Test results did not match the test results of the original code ❌") + console.rule() + return Failure("Test results did not match the test results of the original code.") + def run_optimized_candidate( self, *, @@ -1759,155 +1998,248 @@ def run_optimized_candidate( baseline_results: OriginalCodeBaseline, original_helper_code: dict[Path, str], file_path_to_helper_classes: dict[Path, set[str]], + code_context: CodeOptimizationContext, + candidate: OptimizedCandidate, + exp_type: str, + original_code_baseline, # noqa: ANN001 ) -> Result[OptimizedCandidateResult, str]: - assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 - - with progress_bar("Testing optimization candidate"): - test_env = self.get_test_env( - codeflash_loop_index=0, - codeflash_test_iteration=optimization_candidate_index, - codeflash_tracer_disable=1, - ) - - get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True) - # Instrument codeflash capture - candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8") - candidate_helper_code = {} - for module_abspath in original_helper_code: - candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8") - if self.function_to_optimize.is_async: - from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function - - add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR - ) - - try: - instrument_codeflash_capture( - self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root + current_candidate = candidate + current_candidate_index = optimization_candidate_index + past_trials = "" + for trial_no in range(4): + print("Trial no: ", trial_no) + assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 + with progress_bar("Testing optimization candidate"): + test_env = self.get_test_env( + codeflash_loop_index=0, + codeflash_test_iteration=current_candidate_index, + codeflash_tracer_disable=1, ) - - total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE - candidate_behavior_results, _ = self.run_and_parse_tests( - testing_type=TestingMode.BEHAVIOR, - test_env=test_env, - test_files=self.test_files, - optimization_iteration=optimization_candidate_index, - testing_time=total_looping_time, - enable_coverage=False, - ) - # Remove instrumentation - finally: - self.write_code_and_helpers( - candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path - ) - console.print( - TestResults.report_to_tree( - candidate_behavior_results.get_test_pass_fail_report_by_type(), - title=f"Behavioral Test Results for candidate {optimization_candidate_index}", - ) - ) - console.rule() - if compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results): - logger.info("h3|Test results matched ✅") - console.rule() - else: - logger.info("h4|Test results did not match the test results of the original code ❌") - console.rule() - return Failure("Test results did not match the test results of the original code.") - - logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") - - if test_framework == "pytest": - # For async functions, instrument at definition site for performance benchmarking + get_run_tmp_file(Path(f"test_return_values_{current_candidate_index}.sqlite")).unlink(missing_ok=True) + # Instrument codeflash capture + candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8") + candidate_helper_code = {} + for module_abspath in original_helper_code: + candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8") if self.function_to_optimize.is_async: from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function - add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE + self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR ) - try: - candidate_benchmarking_results, _ = self.run_and_parse_tests( - testing_type=TestingMode.PERFORMANCE, + instrument_codeflash_capture( + self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root + ) + total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE + candidate_behavior_results, _ = self.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, test_env=test_env, test_files=self.test_files, - optimization_iteration=optimization_candidate_index, + optimization_iteration=current_candidate_index, testing_time=total_looping_time, enable_coverage=False, ) + # Remove instrumentation finally: - # Restore original source if we instrumented it - if self.function_to_optimize.is_async: - self.write_code_and_helpers( - candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path - ) - loop_count = ( - max(all_loop_indices) - if ( - all_loop_indices := { - result.loop_index for result in candidate_benchmarking_results.test_results - } + self.write_code_and_helpers( + candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path + ) + console.print( + TestResults.report_to_tree( + candidate_behavior_results.get_test_pass_fail_report_by_type(), + title=f"Behavioral Test Results for candidate {current_candidate_index}", ) - else 0 ) - - else: - candidate_benchmarking_results = TestResults() - start_time: float = time.time() - loop_count = 0 - for i in range(100): - if i >= 5 and time.time() - start_time >= TOTAL_LOOPING_TIME_EFFECTIVE * 1.5: - # * 1.5 to give unittest a bit more time to run + console.rule() + match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results) + if match: + logger.info("h3|Test results matched ✅") + console.rule() + if trial_no!=0: + log_code_repair_to_db( + trace_id=self.function_trace_id, + optimization_id=candidate.optimization_id + "_" + str(trial_no), + passed="yes", + ) + break + if trial_no<=2: + # repair process + ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client + # first candidate + repair_candidate = self.repair_optimization( + original_source_code=code_context.read_writable_code.markdown, + modified_source_code=current_candidate.source_code.markdown, + test_diffs=diffs, + trace_id=self.function_trace_id, + ai_service_client=ai_service_client, + optimization_id=candidate.optimization_id, + past_trials=past_trials, + trial_no=str(trial_no+1) + ) + if not repair_candidate: + logger.debug("llm call failed") + log_code_repair_to_db( + trace_id=self.function_trace_id, + optimization_id=candidate.optimization_id + "_" + str(trial_no+1), + passed="no", + faster="no" + ) + match = False + if trial_no != 2: + continue break - test_env["CODEFLASH_LOOP_INDEX"] = str(i + 1) - unittest_loop_results, _cov = self.run_and_parse_tests( - testing_type=TestingMode.PERFORMANCE, - test_env=test_env, - test_files=self.test_files, - optimization_iteration=optimization_candidate_index, - testing_time=TOTAL_LOOPING_TIME_EFFECTIVE, - unittest_loop_index=i + 1, + try: + # update code + did_update = self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=repair_candidate.source_code, + original_helper_code=original_helper_code, + ) + if not did_update: + log_code_repair_to_db( + trace_id=self.function_trace_id, + optimization_id=candidate.optimization_id + "_" + str(trial_no + 1), + passed="no", + faster="no", + ) + match = False + if trial_no != 2: + continue + break + except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: + logger.error(e) + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + log_code_repair_to_db( + trace_id=self.function_trace_id, + optimization_id=candidate.optimization_id + "_" + str(trial_no + 1), + passed="no", + faster="no", + ) + match = False + if trial_no != 2: + continue + break + past_trials += f"Trial {trial_no + 1}\n" + past_trials += f"Candidate Code\n{current_candidate.source_code.markdown}\n" + past_trials += "Abridged test results\n" + past_trials += build_test_details(diffs)[:2000] + current_candidate = repair_candidate + log_code_repair_to_db( + trace_id=self.function_trace_id, + optimization_id=candidate.optimization_id + "_" + str(trial_no + 1), + passed="no", + faster="no", ) - loop_count = i + 1 - candidate_benchmarking_results.merge(unittest_loop_results) - - if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0: - logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.") - console.rule() - - logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}") + # behavior to test, if pass break + # log the results + # return self.get_results_not_matched_error() + if not match: + print("didn't work after 3 trials abort") + if trial_no!=0: + log_code_repair_to_db( + trace_id=self.function_trace_id, + optimization_id=candidate.optimization_id + "_" + str(trial_no), + passed="no", + faster="no" + ) + return self.get_results_not_matched_error() + # performance benchmark + logger.info(f"loading|Running performance tests for candidate {current_candidate_index}...") - candidate_async_throughput = None + if test_framework == "pytest": + # For async functions, instrument at definition site for performance benchmarking if self.function_to_optimize.is_async: - candidate_async_throughput = calculate_function_throughput_from_test_results( - candidate_benchmarking_results, self.function_to_optimize.function_name + from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function + + add_async_decorator_to_function( + self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE ) - logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second") - if self.args.benchmark: - candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks( - self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root + try: + candidate_benchmarking_results, _ = self.run_and_parse_tests( + testing_type=TestingMode.PERFORMANCE, + test_env=test_env, + test_files=self.test_files, + optimization_iteration=current_candidate_index, + testing_time=total_looping_time, + enable_coverage=False, ) - for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items(): - logger.debug( - f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}" + finally: + # Restore original source if we instrumented it + if self.function_to_optimize.is_async: + self.write_code_and_helpers( + candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path ) - return Success( - OptimizedCandidateResult( - max_loop_count=loop_count, - best_test_runtime=total_candidate_timing, - behavior_test_results=candidate_behavior_results, - benchmarking_test_results=candidate_benchmarking_results, - replay_benchmarking_test_results=candidate_replay_benchmarking_results - if self.args.benchmark - else None, - optimization_candidate_index=optimization_candidate_index, - total_candidate_timing=total_candidate_timing, - async_throughput=candidate_async_throughput, + loop_count = ( + max(all_loop_indices) + if ( + all_loop_indices := { + result.loop_index for result in candidate_benchmarking_results.test_results + } ) + else 0 ) + if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0: + logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.") + console.rule() + + logger.debug(f"Total optimized code {current_candidate_index} runtime (ns): {total_candidate_timing}") + + candidate_async_throughput = None + if self.function_to_optimize.is_async: + candidate_async_throughput = calculate_function_throughput_from_test_results( + candidate_benchmarking_results, self.function_to_optimize.function_name + ) + logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second") + + if self.args.benchmark: + candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks( + self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root + ) + for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items(): + logger.debug( + f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}" + ) + best_test_runtime = total_candidate_timing + perf_gain = performance_gain( + original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime + ) + noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_baseline.runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD + #log here again + report = candidate_behavior_results.get_test_pass_fail_report_by_type() + pass_count = 0 + for test_type in report: + pass_count += report[test_type]["passed"] + + if pass_count >= MIN_TESTCASE_PASSED_THRESHOLD: + speedup_critic_val = True + # If one or more tests passed, check if least one of them was a successful REPLAY_TEST + speedup_critic_val = bool(pass_count >= 6) + faster = "yes" if (perf_gain > noise_floor and speedup_critic_val) else "no" + if trial_no!=0: + log_code_repair_to_db( + trace_id=self.function_trace_id, + optimization_id=candidate.optimization_id + "_" + str(trial_no), + faster=faster, + ) + return Success( + OptimizedCandidateResult( + max_loop_count=loop_count, + best_test_runtime=total_candidate_timing, + behavior_test_results=candidate_behavior_results, + benchmarking_test_results=candidate_benchmarking_results, + replay_benchmarking_test_results=candidate_replay_benchmarking_results + if self.args.benchmark + else None, + optimization_candidate_index=optimization_candidate_index, + total_candidate_timing=total_candidate_timing, + async_throughput=candidate_async_throughput, + ) + ) + def run_and_parse_tests( self, testing_type: TestingMode, diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 9d7f5ba2c..ef7fb910d 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -1,27 +1,52 @@ +from __future__ import annotations + import sys +from typing import TYPE_CHECKING from codeflash.cli_cmds.console import logger -from codeflash.models.models import TestResults, TestType, VerificationType +from codeflash.models.models import TestDiff, TestDiffScope, TestResults, TestType, VerificationType from codeflash.verification.comparator import comparator +if TYPE_CHECKING: + from codeflash.models.models import TestResults + INCREASED_RECURSION_LIMIT = 5000 -def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> bool: +def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]: # This is meant to be only called with test results for the first loop index if len(original_results) == 0 or len(candidate_results) == 0: - return False # empty test results are not equal + return False, [] # empty test results are not equal original_recursion_limit = sys.getrecursionlimit() if original_recursion_limit < INCREASED_RECURSION_LIMIT: sys.setrecursionlimit(INCREASED_RECURSION_LIMIT) # Increase recursion limit to avoid RecursionError test_ids_superset = original_results.get_all_unique_invocation_loop_ids().union( set(candidate_results.get_all_unique_invocation_loop_ids()) ) - are_equal: bool = True + test_diffs: list[TestDiff] = [] did_all_timeout: bool = True for test_id in test_ids_superset: original_test_result = original_results.get_by_unique_invocation_loop_id(test_id) cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id) + candidate_test_failures = candidate_results.test_failures + original_test_failures = original_results.test_failures + try: + cdd_pytest_error = ( + candidate_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if candidate_test_failures + else "" + ) + except: + cdd_pytest_error = "" + try: + original_pytest_error = ( + original_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if original_test_failures + else "" + ) + except: + original_pytest_error="" + if cdd_test_result is not None and original_test_result is None: continue # If helper function instance_state verification is not present, that's ok. continue @@ -32,8 +57,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR ): continue if original_test_result is None or cdd_test_result is None: - are_equal = False - break + continue did_all_timeout = did_all_timeout and original_test_result.timed_out if original_test_result.timed_out: continue @@ -43,42 +67,53 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO} ): superset_obj = True + + test_src_code = original_test_result.id.get_src_code(original_test_result.file_name) + test_diff = TestDiff( + scope=TestDiffScope.RETURN_VALUE, + original_value=repr(original_test_result.return_value), + candidate_value=repr(cdd_test_result.return_value), + test_src_code=test_src_code, + candidate_pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, + original_pytest_error=original_pytest_error, + ) if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): - are_equal = False + test_diff.scope = TestDiffScope.RETURN_VALUE + test_diffs.append(test_diff) + try: logger.debug( - "File Name: %s\n" - "Test Type: %s\n" - "Verification Type: %s\n" - "Invocation ID: %s\n" - "Original return value: %s\n" - "Candidate return value: %s\n" - "-------------------", - original_test_result.file_name, - original_test_result.test_type, - original_test_result.verification_type, - original_test_result.id, - original_test_result.return_value, - cdd_test_result.return_value, + f"File Name: {original_test_result.file_name}\n" + f"Test Type: {original_test_result.test_type}\n" + f"Verification Type: {original_test_result.verification_type}\n" + f"Invocation ID: {original_test_result.id}\n" + f"Original return value: {original_test_result.return_value}\n" + f"Candidate return value: {cdd_test_result.return_value}\n" ) except Exception as e: logger.error(e) - break - if (original_test_result.stdout and cdd_test_result.stdout) and not comparator( + elif (original_test_result.stdout and cdd_test_result.stdout) and not comparator( original_test_result.stdout, cdd_test_result.stdout ): - are_equal = False - break + test_diff.scope = TestDiffScope.STDOUT + test_diff.original_value = str(original_test_result.stdout) + test_diff.candidate_value = str(cdd_test_result.stdout) + test_diffs.append(test_diff) - if original_test_result.test_type in { + elif original_test_result.test_type in { TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST, TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST, } and (cdd_test_result.did_pass != original_test_result.did_pass): - are_equal = False - break + test_diff.scope = TestDiffScope.DID_PASS + test_diff.original_value = str(original_test_result.did_pass) + test_diff.candidate_value = str(cdd_test_result.did_pass) + test_diffs.append(test_diff) + sys.setrecursionlimit(original_recursion_limit) if did_all_timeout: - return False - return are_equal + return False, test_diffs + return len(test_diffs) == 0, test_diffs diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index ef513a0a3..f5cdad9d1 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -512,6 +512,61 @@ def merge_test_results( return merged_test_results +FAILURES_HEADER_RE = re.compile(r"=+ FAILURES =+") +TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$") + + +def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults: + """Extract individual pytest test failures from stdout grouped by test case qualified name, and add them to the test results.""" + lines = stdout.splitlines() + start = end = None + + for i, line in enumerate(lines): + if FAILURES_HEADER_RE.search(line.strip()): + start = i + break + + if start is None: + return test_results + + for j in range(start + 1, len(lines)): + stripped = lines[j].strip() + if "short test summary info" in stripped: + end = j + break + # any new === section === block + if stripped.startswith("=") and stripped.count("=") > 3: + end = j + break + + # If no clear "end", just grap the rest of the string + if end is None: + end = len(lines) + + failure_block = lines[start:end] + + failures: dict[str, str] = {} + current_name = None + current_lines: list[str] = [] + + for line in failure_block: + m = TEST_HEADER_RE.match(line.strip()) + if m: + if current_name is not None: + failures[current_name] = "".join(current_lines) + + current_name = m.group(1) + current_lines = [] + elif current_name: + current_lines.append(line + "\n") + + if current_name: + failures[current_name] = "".join(current_lines) + + test_results.test_failures = failures + return test_results + + def parse_test_results( test_xml_path: Path, test_files: TestFiles, @@ -572,4 +627,9 @@ def parse_test_results( function_name=function_name, ) coverage.log_coverage() + try: + parse_test_failures_from_stdout(results, run_result.stdout) + except Exception as e: + logger.exception(e) + return results, coverage if all_args else None diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index c326cecc4..79133bc15 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -502,7 +502,8 @@ def __init__(self, x=2): pytest_max_loops=1, testing_time=0.1, ) - assert compare_test_results(test_results, test_results2) + match, _ = compare_test_results(test_results, test_results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -626,7 +627,8 @@ def __init__(self, *args, **kwargs): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -754,7 +756,8 @@ def __init__(self, x=2): testing_time=0.1, ) - assert compare_test_results(test_results, test_results2) + match, _ = compare_test_results(test_results, test_results2) + assert match finally: test_path.unlink(missing_ok=True) sample_code_path.unlink(missing_ok=True) @@ -902,7 +905,8 @@ def another_helper(self): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -1132,7 +1136,8 @@ def target_function(self): ) # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) - assert not compare_test_results(test_results, mutated_test_results) + match, _ = compare_test_results(test_results, mutated_test_results) + assert not match # This fto code stopped using a helper class. it should still pass no_helper1_fto_code = """ @@ -1170,10 +1175,304 @@ def target_function(self): ) # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) - assert compare_test_results(test_results, no_helper1_test_results) + match, _ = compare_test_results(test_results, no_helper1_test_results) + assert match finally: test_path.unlink(missing_ok=True) fto_file_path.unlink(missing_ok=True) helper_path_1.unlink(missing_ok=True) helper_path_2.unlink(missing_ok=True) + +def test_instrument_codeflash_capture_and_run_tests_2() -> None: + # End to end run that instruments code and runs tests. Made to be similar to code used in the optimizer.py + test_code = """import math +import pytest +from typing import List, Tuple, Optional +from code_to_optimize.tests.pytest.fto_file import calculate_portfolio_metrics + +def test_calculate_portfolio_metrics(): + # Test case 1: Basic portfolio + investments = [ + ('Stocks', 0.6, 0.12), + ('Bonds', 0.3, 0.04), + ('Cash', 0.1, 0.01) + ] + + result = calculate_portfolio_metrics(investments) + + # Check weighted return calculation + expected_return = 0.6*0.12 + 0.3*0.04 + 0.1*0.01 + assert abs(result['weighted_return'] - expected_return) < 1e-10 + + # Check volatility calculation + expected_vol = math.sqrt((0.6*0.12)**2 + (0.3*0.04)**2 + (0.1*0.01)**2) + assert abs(result['volatility'] - expected_vol) < 1e-10 + + # Check Sharpe ratio + expected_sharpe = (expected_return - 0.02) / expected_vol + assert abs(result['sharpe_ratio'] - expected_sharpe) < 1e-10 + + # Check best/worst performers + assert result['best_performing'][0] == 'Stocks' + assert result['worst_performing'][0] == 'Cash' + assert result['total_assets'] == 3 + +def test_empty_investments(): + with pytest.raises(ValueError, match="Investments list cannot be empty"): + calculate_portfolio_metrics([]) + +def test_weights_not_sum_to_one(): + investments = [('Stock', 0.5, 0.1), ('Bond', 0.4, 0.05)] + with pytest.raises(ValueError, match="Portfolio weights must sum to 1.0"): + calculate_portfolio_metrics(investments) + +def test_zero_volatility(): + investments = [('Cash', 1.0, 0.0)] + result = calculate_portfolio_metrics(investments, risk_free_rate=0.0) + assert result['sharpe_ratio'] == 0.0 + assert result['volatility'] == 0.0 +""" + + original_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10: + raise ValueError("Portfolio weights must sum to 1.0") + + # Calculate weighted return + weighted_return = sum(weight * ret for _, weight, ret in investments) + + # Calculate portfolio volatility (simplified) + volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments)) + + # Calculate Sharpe ratio + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + # Find best and worst performing assets + best_asset = max(investments, key=lambda x: x[2]) + worst_asset = min(investments, key=lambda x: x[2]) + + return { + 'weighted_return': round(weighted_return, 6), + 'volatility': round(volatility, 6), + 'sharpe_ratio': round(sharpe_ratio, 6), + 'best_performing': (best_asset[0], round(best_asset[2], 6)), + 'worst_performing': (worst_asset[0], round(worst_asset[2], 6)), + 'total_assets': len(investments) + } +""" + test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() + test_file_name = "test_multiple_helpers.py" + + fto_file_name = "fto_file.py" + + test_path = test_dir / test_file_name + test_path_perf = test_dir / "test_multiple_helpers_perf.py" + fto_file_path = test_dir / fto_file_name + + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = (Path(__file__).parent / "..").resolve() + + try: + with fto_file_path.open("w") as f: + f.write(original_code) + with test_path.open("w") as f: + f.write(test_code) + + fto = FunctionToOptimize("calculate_portfolio_metrics", fto_file_path, parents=[]) + file_path_to_helper_class = { + } + instrument_codeflash_capture(fto, file_path_to_helper_class, tests_root) + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + + test_type = TestType.EXISTING_UNIT_TEST + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + # Code in optimizer.py + # Instrument codeflash capture + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = { + } + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + + test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + + # Now, let's say we optimize the code and make changes. + new_fto_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + total_weight = sum(w for _, w, _ in investments) + if total_weight != 1.0: # Should use tolerance check + raise ValueError("Portfolio weights must sum to 1.0") + + weighted_return = 1.0 + for _, weight, ret in investments: + weighted_return *= (1 + ret) ** weight + weighted_return = weighted_return - 1.0 # Convert back from geometric + + returns = [r for _, _, r in investments] + mean_return = sum(returns) / len(returns) + volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns)) + + # BUG 4: Sharpe ratio calculation is correct but uses wrong inputs + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + def risk_adjusted_return(return_val, weight): + return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val + + best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1])) + worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1])) + + return { + "weighted_return": round(weighted_return, 6), + "volatility": 2, + "sharpe_ratio": round(sharpe_ratio, 6), + "best_performing": (best_asset[0], round(best_asset[2], 6)), + "worst_performing": (worst_asset[0], round(worst_asset[2], 6)), + "total_assets": len(investments), + } +""" + with fto_file_path.open("w") as f: + f.write(new_fto_code) + # Instrument codeflash capture + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = {} + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + modified_test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + matched, diffs = compare_test_results(test_results, modified_test_results) + + assert not matched + + new_fixed_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + # Tolerant weight check (matches original) + total_weight = sum(weight for _, weight, _ in investments) + if abs(total_weight - 1.0) > 1e-10: + raise ValueError("Portfolio weights must sum to 1.0") + + # Same weighted return as original + weighted_return = sum(weight * ret for _, weight, ret in investments) + + # Same volatility formula as original + volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments)) + + # Same Sharpe ratio logic + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + # Same best/worst logic (based on return only) + best_asset = max(investments, key=lambda x: x[2]) + worst_asset = min(investments, key=lambda x: x[2]) + + return { + "weighted_return": round(weighted_return, 6), + "volatility": round(volatility, 6), + "sharpe_ratio": round(sharpe_ratio, 6), + "best_performing": (best_asset[0], round(best_asset[2], 6)), + "worst_performing": (worst_asset[0], round(worst_asset[2], 6)), + "total_assets": len(investments), + } +""" + with fto_file_path.open("w") as f: + f.write(new_fixed_code) + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = {} + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + modified_test_results_2, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + matched, diffs = compare_test_results(test_results, modified_test_results_2) + # now the test should match and no diffs should be found + assert len(diffs) == 0 + assert matched + + finally: + test_path.unlink(missing_ok=True) + fto_file_path.unlink(missing_ok=True) \ No newline at end of file diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 06d178f95..6c2781229 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -1176,7 +1176,8 @@ def test_compare_results_fn(): ) ) - assert compare_test_results(original_results, new_results_1) + match, _ = compare_test_results(original_results, new_results_1) + assert match new_results_2 = TestResults() new_results_2.add( @@ -1199,7 +1200,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(original_results, new_results_2) + match, _ = compare_test_results(original_results, new_results_2) + assert not match new_results_3 = TestResults() new_results_3.add( @@ -1241,7 +1243,8 @@ def test_compare_results_fn(): ) ) - assert compare_test_results(original_results, new_results_3) + match, _ = compare_test_results(original_results, new_results_3) + assert match new_results_4 = TestResults() new_results_4.add( @@ -1264,7 +1267,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(original_results, new_results_4) + match, _ = compare_test_results(original_results, new_results_4) + assert not match new_results_5_baseline = TestResults() new_results_5_baseline.add( @@ -1308,7 +1312,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(new_results_5_baseline, new_results_5_opt) + match, _ = compare_test_results(new_results_5_baseline, new_results_5_opt) + assert not match new_results_6_baseline = TestResults() new_results_6_baseline.add( @@ -1352,9 +1357,11 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(new_results_6_baseline, new_results_6_opt) + match, _ = compare_test_results(new_results_6_baseline, new_results_6_opt) + assert not match - assert not compare_test_results(TestResults(), TestResults()) + match, _ = compare_test_results(TestResults(), TestResults()) + assert not match def test_exceptions(): diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index ece7d38b0..7bdfa364b 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -223,7 +223,8 @@ def test_sort(): result: [0, 1, 2, 3, 4, 5] """ assert out_str == results2[0].stdout - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") test_path.unlink(missing_ok=True) @@ -368,7 +369,8 @@ def test_sort(): assert test_results[1].return_value == ([0, 1, 2, 3, 4, 5],) out_str = """codeflash stdout : BubbleSorter.sorter() called\n""" assert test_results[1].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[2].id.function_getting_tested == "BubbleSorter.__init__" assert test_results[2].id.test_function_name == "test_sort" assert test_results[2].did_pass @@ -396,7 +398,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match # Replace with optimized code that mutated instance attribute optimized_code = """ @@ -491,7 +494,8 @@ def sorter(self, arr): ) assert new_test_results[3].runtime > 0 assert new_test_results[3].did_pass - assert not compare_test_results(test_results, new_test_results) + match, _ = compare_test_results(test_results, new_test_results) + assert not match finally: fto_path.write_text(original_code, "utf-8") @@ -630,7 +634,8 @@ def test_sort(): out_str = """codeflash stdout : BubbleSorter.sorter_classmethod() called """ assert test_results[0].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_classmethod" assert test_results[1].id.iteration_id == "4_0" @@ -655,7 +660,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") @@ -794,7 +800,8 @@ def test_sort(): out_str = """codeflash stdout : BubbleSorter.sorter_staticmethod() called """ assert test_results[0].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_staticmethod" assert test_results[1].id.iteration_id == "4_0" @@ -819,7 +826,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index cae2c76f1..03556718d 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -221,10 +221,10 @@ def sorter(self, arr): testing_time=0.1, ) # assert test_results_mutated_attr[0].return_value[1]["self"].x == 1 TODO: add self as input to function - assert compare_test_results( + match, _ = compare_test_results( test_results, test_results_mutated_attr ) # Without codeflash capture, the init state was not verified, and the results are verified as correct even with the attribute mutated - + assert match assert test_results_mutated_attr[0].stdout == "codeflash stdout : BubbleSorter.sorter() called\n" finally: fto_path.write_text(original_code, "utf-8") @@ -403,9 +403,10 @@ def sorter(self, arr): assert test_results_mutated_attr[0].return_value[0] == {"x": 1} assert test_results_mutated_attr[0].verification_type == VerificationType.INIT_STATE_FTO assert test_results_mutated_attr[0].stdout == "" - assert not compare_test_results( + match,_ = compare_test_results( test_results, test_results_mutated_attr ) # The test should fail because the instance attribute was mutated + assert not match # Replace with optimized code that did not mutate existing instance attribute, but added a new one optimized_code_new_attr = """ import sys @@ -457,9 +458,10 @@ def sorter(self, arr): assert test_results_new_attr[0].stdout == "" # assert test_results_new_attr[1].return_value[1]["self"].x == 0 TODO: add self as input # assert test_results_new_attr[1].return_value[1]["self"].y == 2 TODO: add self as input - assert compare_test_results( + match,_ = compare_test_results( test_results, test_results_new_attr ) # The test should pass because the instance attribute was not mutated, only a new one was added + assert match finally: fto_path.write_text(original_code, "utf-8") test_path.unlink(missing_ok=True) diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py index c67883c12..c05384d03 100644 --- a/tests/test_pickle_patcher.py +++ b/tests/test_pickle_patcher.py @@ -427,8 +427,8 @@ def bubble_sort_with_unused_socket(data_container): testing_time=1.0, ) assert len(optimized_test_results_unused_socket) == 1 - verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) - assert verification_result is True + match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) + assert match # Remove the previous instrumentation replay_test_path.write_text(original_replay_test_code) @@ -517,8 +517,8 @@ def bubble_sort_with_used_socket(data_container): assert test_results_used_socket.test_results[0].did_pass is False # Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined. - assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False - + match, _ = compare_test_results(test_results_used_socket, optimized_test_results_used_socket) + assert not match finally: # cleanup output_file.unlink(missing_ok=True)