Skip to content

Commit fbecb23

Browse files
committed
Impl typeck & MIR lowering for argument splatting
1 parent 125d25d commit fbecb23

67 files changed

Lines changed: 1771 additions & 192 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

compiler/rustc_ast/src/ast.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3055,9 +3055,30 @@ impl FnDecl {
30553055
pub fn has_self(&self) -> bool {
30563056
self.inputs.get(0).is_some_and(Param::is_self)
30573057
}
3058+
30583059
pub fn c_variadic(&self) -> bool {
30593060
self.inputs.last().is_some_and(|arg| matches!(arg.ty.kind, TyKind::CVarArgs))
30603061
}
3062+
3063+
/// The marker index for "no splatted arguments".
3064+
/// Must have the same value as `FnSigKind::NO_SPLATTED_ARG_INDEX` and `FnDeclFlags::NO_SPLATTED_ARG_INDEX`.
3065+
pub const NO_SPLATTED_ARG_INDEX: u16 = u16::MAX;
3066+
3067+
/// Returns a splatted argument index, if any are present.
3068+
pub fn splatted(&self) -> Option<u16> {
3069+
self.inputs.iter().enumerate().find_map(|(index, arg)| {
3070+
if index == Self::NO_SPLATTED_ARG_INDEX as usize {
3071+
// AST validation has already checked the splatted argument index is valid, so just
3072+
// ignore invalid indexes here.
3073+
None
3074+
} else {
3075+
arg.attrs
3076+
.iter()
3077+
.any(|attr| attr.has_name(sym::splat))
3078+
.then_some(u16::try_from(index).unwrap())
3079+
}
3080+
})
3081+
}
30613082
}
30623083

30633084
/// Is the trait definition an auto trait?

compiler/rustc_ast_lowering/src/delegation.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ impl<'hir, R: ResolverAstLoweringExt<'hir>> LoweringContext<'_, 'hir, R> {
141141

142142
let is_method = self.is_method(sig_id, span);
143143

144-
let (param_count, c_variadic) = self.param_count(sig_id);
144+
let (param_count, c_variadic, splatted) = self.param_count(sig_id);
145145

146146
let mut generics = self.uplift_delegation_generics(delegation, sig_id, item_id);
147147

@@ -153,8 +153,14 @@ impl<'hir, R: ResolverAstLoweringExt<'hir>> LoweringContext<'_, 'hir, R> {
153153
span,
154154
);
155155

156-
let decl =
157-
self.lower_delegation_decl(sig_id, param_count, c_variadic, span, &generics);
156+
let decl = self.lower_delegation_decl(
157+
sig_id,
158+
param_count,
159+
c_variadic,
160+
splatted,
161+
span,
162+
&generics,
163+
);
158164

159165
let sig = self.lower_delegation_sig(sig_id, decl, span);
160166
let ident = self.lower_ident(delegation.ident);
@@ -268,17 +274,18 @@ impl<'hir, R: ResolverAstLoweringExt<'hir>> LoweringContext<'_, 'hir, R> {
268274
self.resolver.get_partial_res(node_id).and_then(|r| r.expect_full_res().opt_def_id())
269275
}
270276

271-
// Function parameter count, including C variadic `...` if present.
272-
fn param_count(&self, def_id: DefId) -> (usize, bool /*c_variadic*/) {
277+
// Function parameter count, including C variadic `...` and `#[splat]` if present.
278+
fn param_count(&self, def_id: DefId) -> (usize, bool /*c_variadic*/, Option<u16> /*splatted*/) {
273279
let sig = self.tcx.fn_sig(def_id).skip_binder().skip_binder();
274-
(sig.inputs().len() + usize::from(sig.c_variadic()), sig.c_variadic())
280+
(sig.inputs().len() + usize::from(sig.c_variadic()), sig.c_variadic(), sig.splatted())
275281
}
276282

277283
fn lower_delegation_decl(
278284
&mut self,
279285
sig_id: DefId,
280286
param_count: usize,
281287
c_variadic: bool,
288+
splatted: Option<u16>,
282289
span: Span,
283290
generics: &GenericsGenerationResults<'hir>,
284291
) -> &'hir hir::FnDecl<'hir> {
@@ -311,7 +318,9 @@ impl<'hir, R: ResolverAstLoweringExt<'hir>> LoweringContext<'_, 'hir, R> {
311318
output: hir::FnRetTy::Return(output),
312319
fn_decl_kind: FnDeclFlags::default()
313320
.set_lifetime_elision_allowed(true)
314-
.set_c_variadic(c_variadic),
321+
.set_c_variadic(c_variadic)
322+
.set_splatted(splatted, inputs.len())
323+
.unwrap(),
315324
})
316325
}
317326

compiler/rustc_ast_lowering/src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1844,12 +1844,15 @@ impl<'hir, R: ResolverAstLoweringExt<'hir>> LoweringContext<'_, 'hir, R> {
18441844
coro: Option<CoroutineKind>,
18451845
) -> &'hir hir::FnDecl<'hir> {
18461846
let c_variadic = decl.c_variadic();
1847+
let mut splatted = decl.splatted();
18471848

18481849
// Skip the `...` (`CVarArgs`) trailing arguments from the AST,
18491850
// as they are not explicit in HIR/Ty function signatures.
18501851
// (instead, the `c_variadic` flag is set to `true`)
18511852
let mut inputs = &decl.inputs[..];
18521853
if decl.c_variadic() {
1854+
// Splat + variadic errors in AST validation, so just ignore one of them here.
1855+
splatted = None;
18531856
inputs = &inputs[..inputs.len() - 1];
18541857
}
18551858
let inputs = self.arena.alloc_from_iter(inputs.iter().map(|param| {
@@ -1937,7 +1940,9 @@ impl<'hir, R: ResolverAstLoweringExt<'hir>> LoweringContext<'_, 'hir, R> {
19371940
}
19381941
}))
19391942
.set_lifetime_elision_allowed(self.resolver.lifetime_elision_allowed(fn_node_id))
1940-
.set_c_variadic(c_variadic);
1943+
.set_c_variadic(c_variadic)
1944+
.set_splatted(splatted, inputs.len())
1945+
.unwrap();
19411946

19421947
self.arena.alloc(hir::FnDecl { inputs, output, fn_decl_kind })
19431948
}

compiler/rustc_ast_passes/src/ast_validation.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ impl<'a> AstValidator<'a> {
350350

351351
fn check_fn_decl(&self, fn_decl: &FnDecl, self_semantic: SelfSemantic) {
352352
self.check_decl_num_args(fn_decl);
353-
self.check_decl_cvariadic_pos(fn_decl);
353+
let c_variadic_span = self.check_decl_cvariadic_pos(fn_decl);
354+
self.check_decl_splatting(fn_decl, c_variadic_span);
354355
self.check_decl_attrs(fn_decl);
355356
self.check_decl_self_param(fn_decl, self_semantic);
356357
}
@@ -368,17 +369,68 @@ impl<'a> AstValidator<'a> {
368369
/// Emits an error if a function declaration has a variadic parameter in the
369370
/// beginning or middle of parameter list.
370371
/// Example: `fn foo(..., x: i32)` will emit an error.
371-
fn check_decl_cvariadic_pos(&self, fn_decl: &FnDecl) {
372+
/// Returns true if a C-variadic parameter is found.
373+
fn check_decl_cvariadic_pos(&self, fn_decl: &FnDecl) -> Option<Span> {
374+
let mut c_variadic_span = None;
375+
372376
match &*fn_decl.inputs {
373377
[ps @ .., _] => {
374378
for Param { ty, span, .. } in ps {
375379
if let TyKind::CVarArgs = ty.kind {
380+
c_variadic_span = Some(*span);
376381
self.dcx().emit_err(errors::FnParamCVarArgsNotLast { span: *span });
377382
}
378383
}
379384
}
380385
_ => {}
381386
}
387+
388+
if let Some(Param { ty, span, .. }) = &fn_decl.inputs.last() {
389+
if let TyKind::CVarArgs = ty.kind {
390+
c_variadic_span = Some(*span);
391+
}
392+
}
393+
394+
c_variadic_span
395+
}
396+
397+
/// Emits an error if a function declaration has more than one splatted argument, with a
398+
/// C-variadic parameter, or a splat at an unsupported index (for performance).
399+
/// Example: `fn foo(#[splat] x: (), #[splat] y: ())` will emit an error.
400+
fn check_decl_splatting(&self, fn_decl: &FnDecl, c_variadic_span: Option<Span>) {
401+
let (splatted_arg_indexes, mut splatted_spans): (Vec<u16>, Vec<Span>) = fn_decl
402+
.inputs
403+
.iter()
404+
.enumerate()
405+
.filter_map(|(index, arg)| {
406+
arg.attrs
407+
.iter()
408+
.any(|attr| attr.has_name(sym::splat))
409+
.then_some((u16::try_from(index).unwrap(), arg.span))
410+
})
411+
.unzip();
412+
413+
// A splatted argument at the "no splatted" marker index is not supported (this is an
414+
// unlikely edge case).
415+
if let (Some(&splatted_arg_index), Some(&splatted_span)) =
416+
(splatted_arg_indexes.last(), splatted_spans.last())
417+
&& splatted_arg_index == FnDecl::NO_SPLATTED_ARG_INDEX
418+
{
419+
self.dcx()
420+
.emit_err(errors::InvalidSplattedArg { splatted_arg_index, span: splatted_span });
421+
}
422+
423+
// Multiple splatted arguments are invalid: we can't know which arguments go in each splat.
424+
if splatted_arg_indexes.len() > 1 {
425+
self.dcx().emit_err(errors::DuplicateSplattedArgs { spans: splatted_spans.clone() });
426+
}
427+
428+
if let Some(c_variadic_span) = c_variadic_span
429+
&& !splatted_spans.is_empty()
430+
{
431+
splatted_spans.push(c_variadic_span);
432+
self.dcx().emit_err(errors::CVarArgsAndSplat { spans: splatted_spans });
433+
}
382434
}
383435

384436
fn check_decl_attrs(&self, fn_decl: &FnDecl) {

compiler/rustc_ast_passes/src/errors.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,33 @@ pub(crate) struct FnParamCVarArgsNotLast {
124124
pub span: Span,
125125
}
126126

127+
#[derive(Diagnostic)]
128+
#[diag("`#[splat]` is not supported on argument index {$splatted_arg_index}")]
129+
#[help("remove `#[splat]`, or use it on an argument closer to the start of the argument list")]
130+
pub(crate) struct InvalidSplattedArg {
131+
pub splatted_arg_index: u16,
132+
133+
#[primary_span]
134+
#[label("`#[splat]` is not supported here")]
135+
pub span: Span,
136+
}
137+
138+
#[derive(Diagnostic)]
139+
#[diag("multiple `#[splat]`s are not allowed in the same function")]
140+
#[help("remove `#[splat]` from all but one argument")]
141+
pub(crate) struct DuplicateSplattedArgs {
142+
#[primary_span]
143+
pub spans: Vec<Span>,
144+
}
145+
146+
#[derive(Diagnostic)]
147+
#[diag("`...` and `#[splat]` are not allowed in the same function")]
148+
#[help("remove `#[splat]` or remove `...`")]
149+
pub(crate) struct CVarArgsAndSplat {
150+
#[primary_span]
151+
pub spans: Vec<Span>,
152+
}
153+
127154
#[derive(Diagnostic)]
128155
#[diag("documentation comments cannot be applied to function parameters")]
129156
pub(crate) struct FnParamDocComment {

compiler/rustc_borrowck/src/diagnostics/region_errors.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ use rustc_middle::bug;
1515
use rustc_middle::hir::place::PlaceBase;
1616
use rustc_middle::mir::{AnnotationSource, ConstraintCategory, ReturnConstraint};
1717
use rustc_middle::ty::{
18-
self, FnSigKind, GenericArgs, Region, RegionVid, Ty, TyCtxt, TypeFoldable, TypeVisitor,
19-
fold_regions,
18+
self, GenericArgs, Region, RegionVid, Ty, TyCtxt, TypeFoldable, TypeVisitor, fold_regions,
2019
};
2120
use rustc_span::{Ident, Span, kw};
2221
use rustc_trait_selection::error_reporting::InferCtxtErrorExt;
@@ -1085,8 +1084,8 @@ impl<'infcx, 'tcx> MirBorrowckCtxt<'_, 'infcx, 'tcx> {
10851084
}
10861085

10871086
// Build a new closure where the return type is an owned value, instead of a ref.
1088-
let fn_sig_kind =
1089-
FnSigKind::default().set_safe(true).set_c_variadic(liberated_sig.c_variadic());
1087+
// The new closure is safe, but otherwise has the same ABI, splat, and c-variadic.
1088+
let fn_sig_kind = liberated_sig.fn_sig_kind.set_safe(true);
10901089
let closure_sig_as_fn_ptr_ty = Ty::new_fn_ptr(
10911090
tcx,
10921091
ty::Binder::dummy(tcx.mk_fn_sig(

compiler/rustc_codegen_cranelift/src/abi/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ pub(crate) fn codegen_fn_prelude<'tcx>(fx: &mut FunctionCx<'_, '_, 'tcx>, start_
265265
.map(|local| {
266266
let arg_ty = fx.monomorphize(fx.mir.local_decls[local].ty);
267267

268+
// FIXME(splat): un-tuple splatted arguments in codegen, for performance
268269
// Adapted from https://github.com/rust-lang/rust/blob/145155dc96757002c7b2e9de8489416e2fdbbd57/src/librustc_codegen_llvm/mir/mod.rs#L442-L482
269270
if Some(local) == fx.mir.spread_arg {
270271
// This argument (e.g. the last argument in the "rust-call" ABI)

compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ fn push_debuginfo_type_name<'tcx>(
375375
output.push_str("fn(");
376376
}
377377

378+
// FIXME(splat): should debuginfo be de-tupled in the callee (and caller)?
378379
if !sig.inputs().is_empty() {
379380
for &parameter_type in sig.inputs() {
380381
push_debuginfo_type_name(tcx, parameter_type, true, output, visited);

compiler/rustc_const_eval/src/const_eval/type_info.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use rustc_ast::Mutability;
77
use rustc_hir::LangItem;
88
use rustc_middle::span_bug;
99
use rustc_middle::ty::layout::TyAndLayout;
10-
use rustc_middle::ty::{self, Const, FnHeader, FnSigTys, ScalarInt, Ty, TyCtxt};
10+
use rustc_middle::ty::{self, Const, FnHeader, FnSigKind, FnSigTys, ScalarInt, Ty, TyCtxt};
1111
use rustc_span::{Symbol, sym};
1212

1313
use crate::const_eval::CompileTimeMachine;
@@ -465,6 +465,15 @@ impl<'tcx> InterpCx<'tcx, CompileTimeMachine<'tcx>> {
465465
sym::variadic => {
466466
self.write_scalar(Scalar::from_bool(fn_sig_kind.c_variadic()), &field_place)?;
467467
}
468+
sym::splat => {
469+
self.write_scalar(
470+
// Use the same encoding as FnSigKind.splatted
471+
Scalar::from_u16(
472+
fn_sig_kind.splatted().unwrap_or(FnSigKind::NO_SPLATTED_ARG_INDEX),
473+
),
474+
&field_place,
475+
)?;
476+
}
468477
other => span_bug!(self.tcx.def_span(field.did), "unimplemented field {other}"),
469478
}
470479
}

0 commit comments

Comments
 (0)