// rustc_hir_analysis/hir_ty_lowering/cmse.rs
use rustc_errors::{DiagCtxtHandle, E0781, struct_span_code_err};
use rustc_hir::{self as hir, HirId};
use rustc_middle::bug;
use rustc_middle::ty::layout::LayoutError;
use rustc_middle::ty::{self, ParamEnv, TyCtxt};
use rustc_target::spec::abi;
use crate::errors;
/// Validates a function signature that uses one of the CMSE ABIs and reports
/// any violation through `dcx`.
///
/// * `CCmseNonSecureCall`: the HIR node identified by `hir_id` must be a bare
///   function pointer type; otherwise `E0781` is emitted and validation stops.
/// * `CCmseNonSecureEntry`: the HIR node must carry a function signature;
///   otherwise the function returns without emitting anything.
/// * Any other ABI is ignored.
///
/// For both CMSE ABIs the inputs are checked with `is_valid_cmse_inputs` and
/// the output with `is_valid_cmse_output`. A failed check produces a
/// `CmseInputsStackSpill` / `CmseOutputStackSpill` error; a layout error
/// produces the ABI's generic error, but only when
/// `should_emit_generic_error` says so.
pub(crate) fn validate_cmse_abi<'tcx>(
    tcx: TyCtxt<'tcx>,
    dcx: DiagCtxtHandle<'_>,
    hir_id: HirId,
    abi: abi::Abi,
    fn_sig: ty::PolyFnSig<'tcx>,
) {
    let abi_name = abi.name();

    match abi {
        abi::Abi::CCmseNonSecureCall => {
            let hir_node = tcx.hir_node(hir_id);
            let hir::Node::Ty(hir::Ty {
                span: bare_fn_span,
                kind: hir::TyKind::BareFn(bare_fn_ty),
                ..
            }) = hir_node
            else {
                // Not a function pointer type. When the parent is a foreign
                // module item, point at that whole item; otherwise point at
                // the node itself.
                let span = match tcx.parent_hir_node(hir_id) {
                    hir::Node::Item(hir::Item {
                        kind: hir::ItemKind::ForeignMod { .. },
                        span,
                        ..
                    }) => *span,
                    _ => tcx.hir().span(hir_id),
                };
                // Report through the caller-supplied handle, like every other
                // diagnostic in this function, rather than through `tcx.dcx()`.
                struct_span_code_err!(
                    dcx,
                    span,
                    E0781,
                    "the `\"C-cmse-nonsecure-call\"` ABI is only allowed on function pointers"
                )
                .emit();
                return;
            };

            match is_valid_cmse_inputs(tcx, fn_sig) {
                Ok(Ok(())) => {}
                Ok(Err(index)) => {
                    // Highlight everything from the name of the first offending
                    // parameter through the type of the last parameter.
                    let span = bare_fn_ty.param_names[index]
                        .span
                        .to(bare_fn_ty.decl.inputs[index].span)
                        .to(bare_fn_ty.decl.inputs.last().unwrap().span);
                    // Plural wording when more than one parameter is highlighted.
                    let plural = bare_fn_ty.param_names.len() - index != 1;
                    dcx.emit_err(errors::CmseInputsStackSpill { span, plural, abi_name });
                }
                Err(layout_err) => {
                    if should_emit_generic_error(abi, layout_err) {
                        dcx.emit_err(errors::CmseCallGeneric { span: *bare_fn_span });
                    }
                }
            }

            match is_valid_cmse_output(tcx, fn_sig) {
                Ok(true) => {}
                Ok(false) => {
                    let span = bare_fn_ty.decl.output.span();
                    dcx.emit_err(errors::CmseOutputStackSpill { span, abi_name });
                }
                Err(layout_err) => {
                    if should_emit_generic_error(abi, layout_err) {
                        dcx.emit_err(errors::CmseCallGeneric { span: *bare_fn_span });
                    }
                }
            };
        }
        abi::Abi::CCmseNonSecureEntry => {
            let hir_node = tcx.hir_node(hir_id);
            let Some(hir::FnSig { decl, span: fn_sig_span, .. }) = hir_node.fn_sig() else {
                // No function signature on this node: nothing is reported here.
                // NOTE(review): presumably misuse of the ABI is diagnosed
                // elsewhere -- confirm.
                return;
            };

            match is_valid_cmse_inputs(tcx, fn_sig) {
                Ok(Ok(())) => {}
                Ok(Err(index)) => {
                    // Highlight from the first offending parameter to the last one.
                    let span = decl.inputs[index].span.to(decl.inputs.last().unwrap().span);
                    let plural = decl.inputs.len() - index != 1;
                    dcx.emit_err(errors::CmseInputsStackSpill { span, plural, abi_name });
                }
                Err(layout_err) => {
                    if should_emit_generic_error(abi, layout_err) {
                        dcx.emit_err(errors::CmseEntryGeneric { span: *fn_sig_span });
                    }
                }
            }

            match is_valid_cmse_output(tcx, fn_sig) {
                Ok(true) => {}
                Ok(false) => {
                    let span = decl.output.span();
                    dcx.emit_err(errors::CmseOutputStackSpill { span, abi_name });
                }
                Err(layout_err) => {
                    if should_emit_generic_error(abi, layout_err) {
                        dcx.emit_err(errors::CmseEntryGeneric { span: *fn_sig_span });
                    }
                }
            };
        }
        _ => (),
    }
}
/// Checks whether the inputs of `fn_sig` stay within a 16-byte budget.
///
/// Every input contributes its size, after which the running total is rounded
/// up to a multiple of `max(4, abi_alignment)`.
///
/// Returns:
/// * `Ok(Ok(()))` when the running total never exceeds 16 bytes,
/// * `Ok(Err(index))` with the index of the first input at which it does,
/// * `Err(layout_error)` when the layout of any input cannot be computed
///   (including inputs that come after the first overflowing one).
fn is_valid_cmse_inputs<'tcx>(
    tcx: TyCtxt<'tcx>,
    fn_sig: ty::PolyFnSig<'tcx>,
) -> Result<Result<(), usize>, &'tcx LayoutError<'tcx>> {
    // Bound regions are replaced by erased ones before querying layouts.
    let fn_sig = tcx.instantiate_bound_regions_with_erased(fn_sig);

    let mut first_overflow: Option<usize> = None;
    let mut used_bytes = 0u64;

    for (index, &input_ty) in fn_sig.inputs().iter().enumerate() {
        let input_layout = tcx.layout_of(ParamEnv::reveal_all().and(input_ty))?;

        let rounding = input_layout.layout.align().abi.bytes().max(4);
        let input_size = input_layout.layout.size().bytes();

        used_bytes = (used_bytes + input_size).next_multiple_of(rounding);

        // Only the first overflowing index is remembered; the scan continues so
        // that layout errors in later inputs are still surfaced.
        if used_bytes > 16 && first_overflow.is_none() {
            first_overflow = Some(index);
        }
    }

    Ok(first_overflow.map_or(Ok(()), Err))
}
/// Checks whether the return type of `fn_sig` is acceptable for the CMSE ABIs.
///
/// * A return type of at most 4 bytes is always accepted.
/// * A return type larger than 8 bytes is always rejected.
/// * In between, `repr(transparent)` wrappers are peeled off and the result is
///   accepted only when the remaining type is exactly `i64`, `u64` or `f64`.
///
/// Returns `Err` when a required layout cannot be computed.
fn is_valid_cmse_output<'tcx>(
    tcx: TyCtxt<'tcx>,
    fn_sig: ty::PolyFnSig<'tcx>,
) -> Result<bool, &'tcx LayoutError<'tcx>> {
    // Bound regions are replaced by erased ones before querying layouts.
    let fn_sig = tcx.instantiate_bound_regions_with_erased(fn_sig);

    let mut ret_ty = fn_sig.output();
    let layout = tcx.layout_of(ParamEnv::reveal_all().and(ret_ty))?;
    let size = layout.layout.size().bytes();

    if size <= 4 {
        return Ok(true);
    } else if size > 8 {
        return Ok(false);
    }

    // Peel `repr(transparent)` layers: replace the wrapper by its first field
    // that is not a 1-ZST, and repeat.
    'outer: loop {
        let ty::Adt(adt_def, args) = ret_ty.kind() else {
            break;
        };

        if !adt_def.repr().transparent() {
            break;
        }

        for variant_def in adt_def.variants() {
            for field_def in variant_def.fields.iter() {
                let ty = field_def.ty(tcx, args);
                let layout = tcx.layout_of(ParamEnv::reveal_all().and(ty))?;

                if !layout.layout.is_1zst() {
                    ret_ty = ty;
                    continue 'outer;
                }
            }
        }

        // No field carries data, so there is nothing left to peel. Without this
        // `break` the loop would re-examine the same type forever (e.g. for an
        // ill-formed transparent type that has already been reported).
        break;
    }

    Ok(ret_ty == tcx.types.i64 || ret_ty == tcx.types.u64 || ret_ty == tcx.types.f64)
}
fn should_emit_generic_error<'tcx>(abi: abi::Abi, layout_err: &'tcx LayoutError<'tcx>) -> bool {
use LayoutError::*;
match layout_err {
Unknown(ty) => {
match abi {
abi::Abi::CCmseNonSecureCall => {
!ty.is_impl_trait()
}
abi::Abi::CCmseNonSecureEntry => true,
_ => bug!("invalid ABI: {abi}"),
}
}
SizeOverflow(..) | NormalizationFailure(..) | ReferencesError(..) | Cycle(..) => {
false }
}
}