diff options
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 232 |
1 file changed, 138 insertions(+), 94 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 5ab8d8a5d6f1..718fc8b7c1d0 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -344,6 +344,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setCondCodeAction(ISD::SETUGT, VT, Expand); setCondCodeAction(ISD::SETUEQ, VT, Expand); setCondCodeAction(ISD::SETUNE, VT, Expand); + + setOperationAction(ISD::FREM, VT, Expand); + setOperationAction(ISD::FPOW, VT, Expand); + setOperationAction(ISD::FPOWI, VT, Expand); + setOperationAction(ISD::FCOS, VT, Expand); + setOperationAction(ISD::FSIN, VT, Expand); + setOperationAction(ISD::FSINCOS, VT, Expand); + setOperationAction(ISD::FEXP, VT, Expand); + setOperationAction(ISD::FEXP2, VT, Expand); + setOperationAction(ISD::FLOG, VT, Expand); + setOperationAction(ISD::FLOG2, VT, Expand); + setOperationAction(ISD::FLOG10, VT, Expand); } } @@ -1135,6 +1147,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); + setOperationAction(ISD::STEP_VECTOR, VT, Custom); setOperationAction(ISD::MULHU, VT, Expand); setOperationAction(ISD::MULHS, VT, Expand); @@ -1167,6 +1180,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv2f64}) { + for (auto InnerVT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, + MVT::nxv2f32, MVT::nxv4f32, MVT::nxv2f64}) { + // Avoid marking truncating FP stores as legal to prevent the + // DAGCombiner from creating unsupported truncating stores. 
+ setTruncStoreAction(VT, InnerVT, Expand); + } + setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); setOperationAction(ISD::MGATHER, VT, Custom); @@ -1387,6 +1407,20 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { // We use EXTRACT_SUBVECTOR to "cast" a scalable vector to a fixed length one. setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + if (VT.isFloatingPoint()) { + setCondCodeAction(ISD::SETO, VT, Expand); + setCondCodeAction(ISD::SETOLT, VT, Expand); + setCondCodeAction(ISD::SETLT, VT, Expand); + setCondCodeAction(ISD::SETOLE, VT, Expand); + setCondCodeAction(ISD::SETLE, VT, Expand); + setCondCodeAction(ISD::SETULT, VT, Expand); + setCondCodeAction(ISD::SETULE, VT, Expand); + setCondCodeAction(ISD::SETUGE, VT, Expand); + setCondCodeAction(ISD::SETUGT, VT, Expand); + setCondCodeAction(ISD::SETUEQ, VT, Expand); + setCondCodeAction(ISD::SETUNE, VT, Expand); + } + // Lower fixed length vector operations to scalable equivalents. 
setOperationAction(ISD::ABS, VT, Custom); setOperationAction(ISD::ADD, VT, Custom); @@ -1399,6 +1433,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { setOperationAction(ISD::CTTZ, VT, Custom); setOperationAction(ISD::FABS, VT, Custom); setOperationAction(ISD::FADD, VT, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); setOperationAction(ISD::FCEIL, VT, Custom); setOperationAction(ISD::FDIV, VT, Custom); setOperationAction(ISD::FFLOOR, VT, Custom); @@ -1420,6 +1455,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { setOperationAction(ISD::MUL, VT, Custom); setOperationAction(ISD::OR, VT, Custom); setOperationAction(ISD::SDIV, VT, Custom); + setOperationAction(ISD::SELECT, VT, Custom); setOperationAction(ISD::SETCC, VT, Custom); setOperationAction(ISD::SHL, VT, Custom); setOperationAction(ISD::SIGN_EXTEND, VT, Custom); @@ -1442,6 +1478,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); @@ -2123,6 +2160,24 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( // Lowering Code //===----------------------------------------------------------------------===// +/// isZerosVector - Check whether SDNode N is a zero-filled vector. +static bool isZerosVector(const SDNode *N) { + // Look through a bit convert. 
+ while (N->getOpcode() == ISD::BITCAST) + N = N->getOperand(0).getNode(); + + if (ISD::isConstantSplatVectorAllZeros(N)) + return true; + + if (N->getOpcode() != AArch64ISD::DUP) + return false; + + auto Opnd0 = N->getOperand(0); + auto *CINT = dyn_cast<ConstantSDNode>(Opnd0); + auto *CFP = dyn_cast<ConstantFPSDNode>(Opnd0); + return (CINT && CINT->isNullValue()) || (CFP && CFP->isZero()); +} + /// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64 /// CC static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) { @@ -3894,9 +3949,13 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Op.getOperand(2)); } case Intrinsic::aarch64_neon_sdot: - case Intrinsic::aarch64_neon_udot: { - unsigned Opcode = IntNo == Intrinsic::aarch64_neon_udot ? AArch64ISD::UDOT - : AArch64ISD::SDOT; + case Intrinsic::aarch64_neon_udot: + case Intrinsic::aarch64_sve_sdot: + case Intrinsic::aarch64_sve_udot: { + unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot || + IntNo == Intrinsic::aarch64_sve_udot) + ? AArch64ISD::UDOT + : AArch64ISD::SDOT; return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); } @@ -4402,6 +4461,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerVECTOR_SHUFFLE(Op, DAG); case ISD::SPLAT_VECTOR: return LowerSPLAT_VECTOR(Op, DAG); + case ISD::STEP_VECTOR: + return LowerSTEP_VECTOR(Op, DAG); case ISD::EXTRACT_SUBVECTOR: return LowerEXTRACT_SUBVECTOR(Op, DAG); case ISD::INSERT_SUBVECTOR: @@ -5107,11 +5168,11 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( const Function &CallerF = MF.getFunction(); CallingConv::ID CallerCC = CallerF.getCallingConv(); - // If this function uses the C calling convention but has an SVE signature, - // then it preserves more registers and should assume the SVE_VectorCall CC. 
+ // Functions using the C or Fast calling convention that have an SVE signature + // preserve more registers and should assume the SVE_VectorCall CC. // The check for matching callee-saved regs will determine whether it is // eligible for TCO. - if (CallerCC == CallingConv::C && + if ((CallerCC == CallingConv::C || CallerCC == CallingConv::Fast) && AArch64RegisterInfo::hasSVEArgsOrReturn(&MF)) CallerCC = CallingConv::AArch64_SVE_VectorCall; @@ -5304,7 +5365,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // Check callee args/returns for SVE registers and set calling convention // accordingly. - if (CallConv == CallingConv::C) { + if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) { bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){ return Out.VT.isScalableVector(); }); @@ -6994,6 +7055,17 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op, return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal); } + if (useSVEForFixedLengthVectorVT(Ty)) { + // FIXME: Ideally this would be the same as above using i1 types, however + // for the moment we can't deal with fixed i1 vector types properly, so + // instead extend the predicate to a result type sized integer vector. + MVT SplatValVT = MVT::getIntegerVT(Ty.getScalarSizeInBits()); + MVT PredVT = MVT::getVectorVT(SplatValVT, Ty.getVectorElementCount()); + SDValue SplatVal = DAG.getSExtOrTrunc(CCVal, DL, SplatValVT); + SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, SplatVal); + return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal); + } + // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a select // instruction. 
if (ISD::isOverflowIntrOpRes(CCVal)) { @@ -9049,6 +9121,20 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, return GenerateTBL(Op, ShuffleMask, DAG); } +SDValue AArch64TargetLowering::LowerSTEP_VECTOR(SDValue Op, + SelectionDAG &DAG) const { + SDLoc dl(Op); + EVT VT = Op.getValueType(); + assert(VT.isScalableVector() && + "Only expect scalable vectors for STEP_VECTOR"); + assert(VT.getScalarType() != MVT::i1 && + "Vectors of i1 types not supported for STEP_VECTOR"); + + SDValue StepVal = Op.getOperand(0); + SDValue Zero = DAG.getConstant(0, dl, StepVal.getValueType()); + return DAG.getNode(AArch64ISD::INDEX_VECTOR, dl, VT, Zero, StepVal); +} + SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); @@ -9663,10 +9749,10 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, } if (i > 0) isOnlyLowElement = false; - if (!isa<ConstantFPSDNode>(V) && !isa<ConstantSDNode>(V)) + if (!isIntOrFPConstant(V)) isConstant = false; - if (isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V)) { + if (isIntOrFPConstant(V)) { ++NumConstantLanes; if (!ConstantValue.getNode()) ConstantValue = V; @@ -9691,7 +9777,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, // Convert BUILD_VECTOR where all elements but the lowest are undef into // SCALAR_TO_VECTOR, except for when we have a single-element constant vector // as SimplifyDemandedBits will just turn that back into BUILD_VECTOR. 
- if (isOnlyLowElement && !(NumElts == 1 && isa<ConstantSDNode>(Value))) { + if (isOnlyLowElement && !(NumElts == 1 && isIntOrFPConstant(Value))) { LLVM_DEBUG(dbgs() << "LowerBUILD_VECTOR: only low element used, creating 1 " "SCALAR_TO_VECTOR node\n"); return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Value); @@ -9832,7 +9918,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, for (unsigned i = 0; i < NumElts; ++i) { SDValue V = Op.getOperand(i); SDValue LaneIdx = DAG.getConstant(i, dl, MVT::i64); - if (!isa<ConstantSDNode>(V) && !isa<ConstantFPSDNode>(V)) + if (!isIntOrFPConstant(V)) // Note that type legalization likely mucked about with the VT of the // source operand, so we may have to convert it here before inserting. Val = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Val, V, LaneIdx); @@ -9932,6 +10018,9 @@ SDValue AArch64TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { assert(Op.getOpcode() == ISD::INSERT_VECTOR_ELT && "Unknown opcode!"); + if (useSVEForFixedLengthVectorVT(Op.getValueType())) + return LowerFixedLengthInsertVectorElt(Op, DAG); + // Check for non-constant or out of range lane. EVT VT = Op.getOperand(0).getValueType(); ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(2)); @@ -9967,8 +10056,11 @@ AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { assert(Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unknown opcode!"); - // Check for non-constant or out of range lane. EVT VT = Op.getOperand(0).getValueType(); + if (useSVEForFixedLengthVectorVT(VT)) + return LowerFixedLengthExtractVectorElt(Op, DAG); + + // Check for non-constant or out of range lane. 
ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(1)); if (!CI || CI->getZExtValue() >= VT.getVectorNumElements()) return SDValue(); @@ -10372,11 +10464,8 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS, SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op, SelectionDAG &DAG) const { - if (Op.getValueType().isScalableVector()) { - if (Op.getOperand(0).getValueType().isFloatingPoint()) - return Op; + if (Op.getValueType().isScalableVector()) return LowerToPredicatedOp(Op, DAG, AArch64ISD::SETCC_MERGE_ZERO); - } if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType())) return LowerFixedLengthVectorSetccToSVE(Op, DAG); @@ -13280,7 +13369,7 @@ static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) { auto isZeroDot = [](SDValue Dot) { return (Dot.getOpcode() == AArch64ISD::UDOT || Dot.getOpcode() == AArch64ISD::SDOT) && - ISD::isBuildVectorAllZeros(Dot.getOperand(0).getNode()); + isZerosVector(Dot.getOperand(0).getNode()); }; if (!isZeroDot(Dot)) std::swap(Dot, A); @@ -13911,78 +14000,7 @@ static SDValue performExtendCombine(SDNode *N, return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD); } - - // This is effectively a custom type legalization for AArch64. - // - // Type legalization will split an extend of a small, legal, type to a larger - // illegal type by first splitting the destination type, often creating - // illegal source types, which then get legalized in isel-confusing ways, - // leading to really terrible codegen. E.g., - // %result = v8i32 sext v8i8 %value - // becomes - // %losrc = extract_subreg %value, ... - // %hisrc = extract_subreg %value, ... - // %lo = v4i32 sext v4i8 %losrc - // %hi = v4i32 sext v4i8 %hisrc - // Things go rapidly downhill from there. - // - // For AArch64, the [sz]ext vector instructions can only go up one element - // size, so we can, e.g., extend from i8 to i16, but to go from i8 to i32 - // take two instructions. 
- // - // This implies that the most efficient way to do the extend from v8i8 - // to two v4i32 values is to first extend the v8i8 to v8i16, then do - // the normal splitting to happen for the v8i16->v8i32. - - // This is pre-legalization to catch some cases where the default - // type legalization will create ill-tempered code. - if (!DCI.isBeforeLegalizeOps()) - return SDValue(); - - // We're only interested in cleaning things up for non-legal vector types - // here. If both the source and destination are legal, things will just - // work naturally without any fiddling. - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - EVT ResVT = N->getValueType(0); - if (!ResVT.isVector() || TLI.isTypeLegal(ResVT)) - return SDValue(); - // If the vector type isn't a simple VT, it's beyond the scope of what - // we're worried about here. Let legalization do its thing and hope for - // the best. - SDValue Src = N->getOperand(0); - EVT SrcVT = Src->getValueType(0); - if (!ResVT.isSimple() || !SrcVT.isSimple()) - return SDValue(); - - // If the source VT is a 64-bit fixed or scalable vector, we can play games - // and get the better results we want. - if (SrcVT.getSizeInBits().getKnownMinSize() != 64) - return SDValue(); - - unsigned SrcEltSize = SrcVT.getScalarSizeInBits(); - ElementCount SrcEC = SrcVT.getVectorElementCount(); - SrcVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize * 2), SrcEC); - SDLoc DL(N); - Src = DAG.getNode(N->getOpcode(), DL, SrcVT, Src); - - // Now split the rest of the operation into two halves, each with a 64 - // bit source. 
- EVT LoVT, HiVT; - SDValue Lo, Hi; - LoVT = HiVT = ResVT.getHalfNumVectorElementsVT(*DAG.getContext()); - - EVT InNVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getVectorElementType(), - LoVT.getVectorElementCount()); - Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InNVT, Src, - DAG.getConstant(0, DL, MVT::i64)); - Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InNVT, Src, - DAG.getConstant(InNVT.getVectorMinNumElements(), DL, MVT::i64)); - Lo = DAG.getNode(N->getOpcode(), DL, LoVT, Lo); - Hi = DAG.getNode(N->getOpcode(), DL, HiVT, Hi); - - // Now combine the parts back together so we still have a single result - // like the combiner expects. - return DAG.getNode(ISD::CONCAT_VECTORS, DL, ResVT, Lo, Hi); + return SDValue(); } static SDValue splitStoreSplat(SelectionDAG &DAG, StoreSDNode &St, @@ -15213,7 +15231,8 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) { } } - if (N0.getOpcode() != ISD::SETCC || CCVT.getVectorNumElements() != 1 || + if (N0.getOpcode() != ISD::SETCC || + CCVT.getVectorElementCount() != ElementCount::getFixed(1) || CCVT.getVectorElementType() != MVT::i1) return SDValue(); @@ -17221,6 +17240,35 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE( return convertFromScalableVector(DAG, VT, Val); } +SDValue AArch64TargetLowering::LowerFixedLengthExtractVectorElt( + SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + EVT InVT = Op.getOperand(0).getValueType(); + assert(InVT.isFixedLengthVector() && "Expected fixed length vector type!"); + + SDLoc DL(Op); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT); + SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0)); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Op.getOperand(1)); +} + +SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt( + SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isFixedLengthVector() && "Expected fixed length vector type!"); + + 
SDLoc DL(Op); + EVT InVT = Op.getOperand(0).getValueType(); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT); + SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0)); + + auto ScalableRes = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT, Op0, + Op.getOperand(1), Op.getOperand(2)); + + return convertFromScalableVector(DAG, VT, ScalableRes); +} + // Convert vector operation 'Op' to an equivalent predicated operation whereby // the original operation's type is used to construct a suitable predicate. // NOTE: The results for inactive lanes are undefined. @@ -17437,10 +17485,6 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE( assert(Op.getValueType() == InVT.changeTypeToInteger() && "Expected integer result of the same bit length as the inputs!"); - // Expand floating point vector comparisons. - if (InVT.isFloatingPoint()) - return SDValue(); - auto Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0)); auto Op2 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1)); auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT); |