diff --git a/sacrebleu.py b/sacrebleu.py
index bd1160f5..74186c43 100755
--- a/sacrebleu.py
+++ b/sacrebleu.py
@@ -673,25 +673,22 @@ def download_test_set(test_set, langpair=None):
 BLEU = namedtuple('BLEU', 'score, ngram1, ngram2, ngram3, ngram4, bp, sys_len, ref_len')
 
 
-def compute_bleu(instream, refstreams, smooth=0., force=False, lc=False, tokenize=False) -> BLEU:
+def compute_bleu(instream, refstreams, smooth=0., force=False, lc=False, tokenize=False, bootstrap_trials=1) -> BLEU:
     """Produces the BLEU scores along with its sufficient statistics from a source against one or more references.
 
     :param instream: the input stream, one segment per line
     :param refstreams: a list of reference streams
+    :param bootstrap_trials: number of trials for bootstrap resampling (default: 1)
     :return: a BLEU object containing everything you'd want
     """
 
     fhs = [sys.stdin] + refstreams
 
-    sys_len = 0
-    ref_len = 0
-
-    correct = defaultdict(int)
-    total = defaultdict(int)
-
     # look for already-tokenized sentences
     tokenized_count = 0
 
+    # Pre-compute segment-level data for the BLEU computation.
+    segmentdata = defaultdict(list)
     for sentno, lines in enumerate(zip(*fhs)):
         if lc:
             lines = [x.lower() for x in lines]
@@ -706,35 +703,122 @@ def compute_bleu(instream, refstreams, smooth=0., force=False, lc=False, tokeniz
             sys.exit(1)
 
         output, *refs = [tokenizers[tokenize](x.rstrip()) for x in lines]
-
+        sys_ngrams = extract_ngrams(output)
         ref_ngrams, closest_diff, closest_len = ref_stats(output, refs)
 
-        sys_len += len(output.split())
-        ref_len += closest_len
+        local_correct = defaultdict(int)
+        local_total = defaultdict(int)
 
-        sys_ngrams = extract_ngrams(output)
         for ngram in sys_ngrams.keys():
             n = len(ngram.split())
-            total[n] += sys_ngrams[ngram]
-            correct[n] += min(sys_ngrams[ngram], ref_ngrams.get(ngram, 0))
-
-    if sum(total) == 0:
-        logging.error('No input?')
-        sys.exit(1)
+            local_total[n] += sys_ngrams[ngram]
+            local_correct[n] += min(sys_ngrams[ngram], ref_ngrams.get(ngram, 0))
+
+        segmentdata[sentno].append(len(output.split()))  # 0: output_len
+        segmentdata[sentno].append(closest_diff)         # 1: closest_diff (unused)
+        segmentdata[sentno].append(closest_len)          # 2: closest_len
+        segmentdata[sentno].append(local_total)          # 3: local_total
+        segmentdata[sentno].append(local_correct)        # 4: local_correct
+
+    # Based on the pre-computed segment-level data, compute the BLEU score for the input.
+    #
+    # This requires seeding the RNG to get reproducible results. For now,
+    # we simply freeze the seed value as 12345. This can later be changed
+    # so that it is configurable. If so, the random seed needs to become
+    # part of the sacreBLEU signature for future reference.
+    from random import seed, randrange
+    seed(12345)
+
+    # Number of segments in the test set.
+    set_size = len(segmentdata.keys())
+
+    trial_runs = []
+    for trial_run in range(bootstrap_trials):
+        sys_len = 0
+        ref_len = 0
+
+        correct = defaultdict(int)
+        total = defaultdict(int)
+
+        # The first trial run always uses the full test set, which gives the
+        # desired behaviour for bootstrap_trials=1, i.e., a single run.
+        if trial_run == 0:
+            input_data = segmentdata.keys()
+
+        # Subsequent trial runs draw segment indices with replacement.
+        else:
+            input_data = (randrange(set_size) for _ in range(set_size))
 
-    precisions = [0, 0, 0, 0, 0]
+        # Compute the BLEU score for the current trial, based on pre-computed data.
+        for sentno in input_data:
+            output_len = segmentdata[sentno][0]
+            closest_diff = segmentdata[sentno][1]
+            closest_len = segmentdata[sentno][2]
+            local_total = segmentdata[sentno][3]
+            local_correct = segmentdata[sentno][4]
 
-    for n in range(1, 5):
-        precisions[n] = max(smooth, 100. * correct[n] / total[n] if total.get(n) > 0 else 0)
+            sys_len += output_len
+            ref_len += closest_len
 
-    brevity_penalty = 1.0
-    if sys_len < ref_len:
-        brevity_penalty = math.exp(1 - ref_len / sys_len)
+            for n in local_total.keys():
+                total[n] += local_total[n]
+                correct[n] += local_correct[n]
 
-    bleu = 1. * brevity_penalty * math.exp(sum(map(my_log, precisions[1:])) / 4)
+        if sum(total) == 0:
+            logging.error('No input?')
+            sys.exit(1)
 
-    return BLEU._make([bleu, precisions[1], precisions[2], precisions[3], precisions[4], brevity_penalty, sys_len, ref_len])
+        precisions = [0, 0, 0, 0, 0]
+
+        for n in range(1, 5):
+            precisions[n] = max(smooth, 100. * correct[n] / total[n] if total.get(n) > 0 else 0)
+
+        brevity_penalty = 1.0
+        if sys_len < ref_len:
+            brevity_penalty = math.exp(1 - ref_len / sys_len)
+
+        bleu = 1. * brevity_penalty * math.exp(sum(map(my_log, precisions[1:])) / 4)
+        trial_runs.append([bleu, precisions[1], precisions[2], precisions[3], precisions[4], brevity_penalty, sys_len, ref_len])
+
+    # Compute average BLEU score and component values across all trials.
+    avgBleu = [
+        sum(x[0] for x in trial_runs) / len(trial_runs),       # bleu
+        sum(x[1] for x in trial_runs) / len(trial_runs),       # precisions[1]
+        sum(x[2] for x in trial_runs) / len(trial_runs),       # precisions[2]
+        sum(x[3] for x in trial_runs) / len(trial_runs),       # precisions[3]
+        sum(x[4] for x in trial_runs) / len(trial_runs),       # precisions[4]
+        sum(x[5] for x in trial_runs) / len(trial_runs),       # brevity_penalty
+        int(sum(x[6] for x in trial_runs) / len(trial_runs)),  # sys_len
+        int(sum(x[7] for x in trial_runs) / len(trial_runs)),  # ref_len
+    ]
+
+    if bootstrap_trials > 1:
+        print('Bootstrap trials: n={0}'.format(bootstrap_trials))
+        allBleuScores = [x[0] for x in trial_runs]
+        try:
+            from numpy import mean, std
+            from math import sqrt
+
+            # Compute 0.95 confidence interval around BLEU score mean.
+            xbar = mean(allBleuScores)
+            s = std(allBleuScores)
+            sqrtn = sqrt(bootstrap_trials)
+            z = 1.96
+            confidenceInterval = z * s / sqrtn
+
+        except ImportError:
+            logging.error('Could not import numpy for confidence interval computation')
+            xbar = sum(allBleuScores) / len(allBleuScores)
+            confidenceInterval = None
+
+        finally:
+            if confidenceInterval:
+                print('Mean BLEU score: {0:.2f} +/- {1:.2f}'.format(xbar, confidenceInterval))
+            else:
+                print('Mean BLEU score: {0:.2f}'.format(xbar))
+
+    return BLEU._make(avgBleu)
 
 
 def main():
@@ -766,6 +850,8 @@ def main():
                             help='Suppress informative output.')
     arg_parser.add_argument('--encoding', '-e', type=str, default='utf-8',
                             help='Open text files with specified encoding (default: %(default)s)')
+    arg_parser.add_argument('--bootstrap-trials', '-b', type=int, default=1,
+                            help='Compute BLEU based on bootstrap resampling with n trials (default: %(default)d)')
    arg_parser.add_argument('-V', '--version', action='version',
                            version='%(prog)s {}'.format(VERSION))
    args = arg_parser.parse_args()
@@ -820,7 +906,7 @@ def main():
         logging.warn('You should also pass "--tok zh" when scoring Chinese...')
 
     bleu = compute_bleu(sys.stdin, refs, smooth=args.smooth, force=args.force,
-                        lc=args.lc, tokenize=args.tokenize)
+                        lc=args.lc, tokenize=args.tokenize, bootstrap_trials=args.bootstrap_trials)
 
     version_str = build_signature(args, len(refs))
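
Note on the reported interval: the "+/-" value printed by this patch is a plain normal-approximation 95% confidence interval over the per-trial BLEU scores (z = 1.96, numpy's default population standard deviation). A minimal standalone sketch of that same calculation, using made-up trial scores purely for illustration (not part of the patch):

    from math import sqrt

    # Hypothetical per-trial BLEU scores, standing in for [x[0] for x in trial_runs].
    trial_scores = [27.1, 26.8, 27.4, 27.0, 26.9]

    n = len(trial_scores)
    xbar = sum(trial_scores) / n                              # mean BLEU over trials
    s = sqrt(sum((x - xbar) ** 2 for x in trial_scores) / n)  # population std, as numpy.std computes by default
    ci = 1.96 * s / sqrt(n)                                   # 95% normal-approximation half-width

    print('Mean BLEU score: {0:.2f} +/- {1:.2f}'.format(xbar, ci))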