The following takes a closer look at CGRATensor_ArithOp_V_VuV_mul_loop(bf16 *src, bf16 *dst, bf16 *unit, int rnd, int src_elem_num, int unit_elem_num, int full_src_elem_num, int full_unit_elem_num), which is used in the BatchNorm (BN) operator implementation.
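As a rough mental model, this primitive multiplies a source vector element-wise by a short per-channel "unit" vector that repeats across the source. The sketch below is only a plain-C++ reference, not the hardware intrinsic: bf16 is approximated by float, the rounding mode rnd is ignored, and the full_* arguments (assumed here to be padded buffer lengths) are unused.

// Hypothetical scalar reference model of the broadcast-multiply-loop primitive.
using bf16 = float;  // approximation for illustration only

// dst[i] = src[i] * unit[i % unit_elem_num], i.e. the short unit vector is
// broadcast repeatedly over the longer source vector.
void VuV_mul_loop_ref(const bf16 *src, bf16 *dst, const bf16 *unit,
                      int src_elem_num, int unit_elem_num) {
  for (int i = 0; i < src_elem_num; ++i) {
    dst[i] = src[i] * unit[i % unit_elem_num];
  }
}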
The backend IR is built on MLIR: a custom Dialect is derived, on which many Operations, Attributes, and Types are defined.
def Tx8be_Dialect : Dialect {
  let name = "tx8be";
  let summary = "A low-level dialect for tx8 backend specification";
  let cppNamespace = "::tx8be_mir::tx8be";
  let useDefaultAttributePrinterParser = 1;
}
def Tx8be_ParallelAttr : Tx8be_Attr<"Parallel", "parallel_attr"> {
  let summary = "Structure of parallel information.";
  let parameters = (ins
    "ParallelModeAttr":$parallel,
    "bool":$is_dp_inner,                                 // dp dimension is in the inner, otherwise tp
    "i32":$dp_dim_x,                                     // data parallel dimension at x axis
    "i32":$dp_dim_y,                                     // data parallel dimension at y axis
    "i32":$dp_dim_z,                                     // data parallel dimension at z axis
    "i32":$tp_dim_x,                                     // tensor parallel dimension at x axis
    "i32":$tp_dim_y,                                     // tensor parallel dimension at y axis
    "i32":$tp_dim_z,                                     // tensor parallel dimension at z axis
    "bool":$sharding_is_given,                           // true: is given, false: is not
    "::mlir::DenseI32ArrayAttr":$shape_spatial_sharding  // shape split info
  );
  let cppNamespace = "::tx8be_mir::tx8be";
  let assemblyFormat = "`<` struct($params) `>`";
}
The dev_attr attribute contains:
imm_size, the size of the auxiliary (immediate) buffer the op uses.
mem_layout, the layout in which the data is stored.
multi_buf_en, whether double buffering is used.
out_shape_buf_idx, which output buffer index is used.
temporal_mem_slice, the amount of data a single Tile processes per pass.
def Tx8be_DevAttr : Tx8be_Attr<"Dev", "dev_attr"> {
  let summary = "Structure of op parameters on device.";
  let parameters = (ins
    "uint64_t":$imm_size,                          // output memory addr offset
    "LayoutModeAttr":$mem_layout,                  // layout
    "bool":$multi_buf_en,                          // for double buffering
    "int32_t":$multi_buf_num,                      // for double buffering
    "mlir::DenseI64ArrayAttr":$out_shape_buf_idx,  // index for dynamic shape buffer on runtime
    "mlir::DenseI64ArrayAttr":$temporal_mem_slice, // for computing local buffer size
    "int32_t":$source_type,                        // software pipeline stage
    "int64_t":$imm_addr,
    "mlir::DenseI64ArrayAttr":$mem_addr            // use array for multibuffer
  );
  let cppNamespace = "::tx8be_mir::tx8be";
  let assemblyFormat = "`<` struct($params) `>`";
}
MemScopeMode describes where the data is stored.
def Tx8be_MemScopeMode : I32EnumAttr<"MemScopeMode", "Specify the memory scope", [
    I32EnumAttrCase<"DDR", 0>,
    I32EnumAttrCase<"SPM", 1>,
    I32EnumAttrCase<"3DDRAM", 2>
]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::tx8be_mir::tx8be";
}
def Tx8be_BatchNorm_InferenceOp : Tx8be_Op<"BatchNorm_Inference", [
    DeclareOpInterfaceMethods<OpLibInterface>,
    DeclareOpInterfaceMethods<ShardingInterface>,
    DeclareOpInterfaceMethods<ComputeInterface>]> {
  let summary = "BatchNorm inference";
  let description = [{
    Normalizes the operand tensor across all dimensions except for the c dimension
    and produces a result tensor.
  }];
  let arguments = (ins
    AnyTensor:$input,
    AnyTensor:$scale,
    AnyTensor:$offset,
    AnyTensor:$mean,
    AnyTensor:$variance,
    DefaultValueOptionalStrAttr<StrAttr, "Unknown">:$layout_str,
    // The following are backend parameters
    OptionalAttr<Tx8be_ParallelAttr>:$chip_parallel,
    OptionalAttr<Tx8be_ParallelAttr>:$tile_parallel,
    OptionalAttr<Tx8be_DevAttr>:$dev_info
  );
  let results = (outs AnyTensor:$output);
}
def ShapeInferenceOpInterface : OpInterface<"ShapeInferenceOpInterface"> {
  let description = [{}];
  let cppNamespace = "::tx8be_mlir";
  let methods = [
    InterfaceMethod<[{}],
      /*retType=*/"mlir::LogicalResult",
      /*methodName=*/"inferShapes",
      /*args=*/(ins "DynamicShapeParam":$shapeParam)>,
    InterfaceMethod<[{}],
      /*retType=*/"mlir::LogicalResult",
      /*methodName=*/"inferLayout",
      /*args=*/(ins)>
  ];
}
std::vector<ShardingSplitParam> tx8be::BatchNorm_InferenceOp::tileShardingSplit(ShardingSplitParam &param) {
  auto shape = getOutput().getType().getShape();
  ASSERT(shape.size() == param.outSharding.size() && shape.size() == param.outSplit.size());
  int32_t shape_size = shape.size();
  std::vector<ShardingSplitParam> result;
  result.emplace_back(param);  // input
  for (int32_t i = 0; i < shape_size - 1; ++i) {
    if (result[0].outSharding.size() > 0 && result[0].outSharding[i] != 1) {
      // can only shard in dim C
      result[0].outSharding.clear();
    }
    if (result[0].outSplit.size() > 0 && result[0].outSplit[shape_size - 1] != 1) {
      // can only split except dim C
      result[0].outSplit.clear();
    }
  }
  ShardingSplitParam paramMean;  // scale/shift/mean/variance
  if (result[0].outSharding.size() > 0) {
    paramMean.outSharding = result[0].outSharding;
  }
  paramMean.outSplit = std::vector<int32_t>(shape_size, 1);  // shape is 1x1x1xC, split must be (1, 1, 1, 1)
  ShardingSplitParam paramVar = paramMean;
  ShardingSplitParam paramScale = paramMean;
  ShardingSplitParam paramShift = paramMean;
  result.emplace_back(paramScale);
  result.emplace_back(paramShift);
  result.emplace_back(paramMean);
  result.emplace_back(paramVar);
  return result;
}
typedef struct L_SHAPE {
  int32_t shape_whole[MAX_SHAPE_DIM];  // the whole shape
  int32_t shape_start[MAX_SHAPE_DIM];  // start idx of the shape slice
  int32_t shape_slice[MAX_SHAPE_DIM];  // length of the shape slice
  int32_t shape_real[MAX_SHAPE_DIM];   // real length of the shape slice
  int32_t dim;                         // dimension of the shape
} L_SHAPE;

typedef struct G_SHAPE {
  int32_t spatial_start[MAX_SHAPE_DIM];   // [start, end]
  int32_t spatial_end[MAX_SHAPE_DIM];
  int32_t dynamic_offset[MAX_SHAPE_DIM];
  int32_t shape[MAX_SHAPE_DIM];
  int32_t dim;
  int32_t done;                           // done for dma load finish
  int32_t batch_offset[MAX_SHAPE_DIM];
} G_SHAPE;

typedef struct TSR {
  Data_Format format;
  uint64_t addr;
  L_SHAPE *shape;
} TSR;
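For orientation, here is a hypothetical example of how L_SHAPE could describe one tile's slice of a (4, 112, 2, 30) tensor split along H (assuming MAX_SHAPE_DIM >= 4; field usage follows the comments above):

L_SHAPE s = {};
s.dim = 4;
int32_t whole[4] = {4, 112, 2, 30};   // the full NHWC shape
int32_t start[4] = {0, 56, 0, 0};     // this tile starts at H = 56
int32_t slice[4] = {4, 56, 2, 30};    // nominal slice length per dimension
for (int i = 0; i < 4; ++i) {
  s.shape_whole[i] = whole[i];
  s.shape_start[i] = start[i];
  s.shape_slice[i] = slice[i];
  s.shape_real[i]  = slice[i];        // no tail remainder in this example
}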
BatchNorm Design
For non-fp32 data types (fp16 as the example), the computation flow and buffer allocation are shown in the figure below; a scalar sketch of the same math follows the figure.
Convert the type to fp32: gatherScatter.
Call the fp16->fp32 conversion routine.
Compute x - Mean in a loop (since the NHW dimensions of in are split), storing the result in imm_a.
Add epsilon (1e-6) to Variance in place.
Apply rsqrt to Variance.
Loop-multiply Variance with x - Mean.
Loop-multiply by scale.
Loop-add shift.
Convert fp32 back to fp16.
gatherScatter to out.
Batchnorm Computation Flow
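A scalar sketch of the per-element math behind this flow (assuming fp32 intermediates; the hardware path widens fp16 via gatherScatter plus an fp16->fp32 convert and works on tiled buffers rather than single elements):

#include <cmath>

float batchnorm_ref(float x, float mean, float variance,
                    float scale, float shift, float epsilon = 1e-6f) {
  float centered = x - mean;                               // x - Mean, stored in imm_a
  float inv_std  = 1.0f / std::sqrt(variance + epsilon);   // rsqrt(Variance + eps)
  return centered * inv_std * scale + shift;               // loop-mul scale, loop-add shift
}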
Note the multiplication of the per-channel (1, 1, 1, C) tensor (shift/scale) with the normalized x of shape (N, H, W, C): this is exactly where the VuV_mul and VuV_mul_loop instructions mentioned earlier come in.
When C <= 32, the data within one batch is laid out as follows (taking (4x112x2x30) x (1x1x1x30) as the example); in this case it is enough to call the VuV_mul instruction in a loop over the batch dimension.
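A minimal host-side sketch of that batch loop, assuming row-major NHWC float data; VuV_mul_ref below is only an illustrative stand-in for the real bf16 CGRATensor_ArithOp_V_VuV_mul call:

void VuV_mul_ref(const float *src, float *dst, const float *unit,
                 int src_elem_num, int unit_elem_num) {
  for (int i = 0; i < src_elem_num; ++i)
    dst[i] = src[i] * unit[i % unit_elem_num];   // broadcast the (1,1,1,C) operand
}

void mul_per_channel_over_batches(const float *x, float *out, const float *chan,
                                  int N, int H, int W, int C) {
  const int per_batch = H * W * C;               // e.g. 112 * 2 * 30 = 6720
  for (int n = 0; n < N; ++n)                    // loop over the batch dimension
    VuV_mul_ref(x + n * per_batch, out + n * per_batch, chan, per_batch, C);
}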
The ComputeInterface is mainly how each OP obtains CPU reference code through oneDNN. For computationally simple OPs with no matching oneDNN primitive, the compute interface can instead carry a hand-written C++ CPU implementation of the OP. The generated results are used to verify operator correctness (see the sketch after the interface definition).
def ComputeInterface : OpInterface<"ComputeInterface"> {
  let description = [{}];
  let cppNamespace = "::tx8be_mlir";
  let methods = [
    InterfaceMethod</*desc=*/[{}],
      /*retType=*/"::mlir::LogicalResult",
      /*methodName=*/"compute",
      /*args=*/(ins "ComputeParam&":$param)>,
  ];
}
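As an illustration of the hand-written fallback path, this is the kind of naive NHWC loop a compute() body could run for BatchNorm inference. It is written as a free function for clarity; the real hook receives a ComputeParam, whose layout is not shown in this document.

#include <cmath>
#include <cstdint>

// Naive CPU reference over an NHWC tensor; scale/offset/mean/var are 1x1x1xC.
void batchNormInferenceRef(const float *in, const float *scale, const float *offset,
                           const float *mean, const float *var, float *out,
                           int64_t outer /* N*H*W */, int64_t C, float eps = 1e-6f) {
  for (int64_t i = 0; i < outer; ++i) {
    for (int64_t c = 0; c < C; ++c) {
      int64_t idx = i * C + c;
      out[idx] = (in[idx] - mean[c]) / std::sqrt(var[c] + eps) * scale[c] + offset[c];
    }
  }
}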
// const can be directly considered to be aligned
// constop(dim < 2) -> channelNorm -> constop
struct ConstChannelNormErase : public mlir::OpRewritePattern<tx8be::ConstantOp> {
  ConstChannelNormErase(mlir::MLIRContext *context)
      : OpRewritePattern<tx8be::ConstantOp>(context, /*benefit=*/1) {}

  mlir::LogicalResult matchAndRewrite(tx8be::ConstantOp op,
                                      mlir::PatternRewriter &rewriter) const override {
    // If the const has multiple users, it cannot be erased.
    if (!op->hasOneUse()) return mlir::failure();
    auto user = *op->getUsers().begin();
    if (!isa<tx8be::ChannelNormOp>(user)) return mlir::failure();
    auto shape = op->getResult(0).getType().dyn_cast<mlir::ShapedType>().getShape();
    if (shape.size() > 1) return mlir::failure();
    llvm::SmallVector<Operation *> userVec;
    userVec.insert(userVec.end(), user->getUsers().begin(), user->getUsers().end());
    for (auto channelNormUser : userVec) {
      channelNormUser->replaceUsesOfWith(user->getResult(0), op->getResult(0));
    }
    // set align=true
    setDevInfoWithLayout(op->getContext(), op.getLayoutStr().str(), true);
    if (user->use_empty()) rewriter.eraseOp(user);
    return success();
  }
};
// A pattern to erase redundant channel normalization operations
struct RedundantChannelNormErase : public mlir::OpRewritePattern<tx8be::ChannelNormOp> {
  RedundantChannelNormErase(mlir::MLIRContext *context)
      : OpRewritePattern<tx8be::ChannelNormOp>(context, /*benefit=*/1) {}

  mlir::LogicalResult matchAndRewrite(tx8be::ChannelNormOp op,
                                      mlir::PatternRewriter &rewriter) const override {
    // def is the operation that produces the input of op.
    auto def = op.getInput().getDefiningOp();
    // Bail out if the defining operation is a ConstantOp with more than one result.
    if (isa<tx8be::ConstantOp>(def) && (def->getNumResults() > 1)) {
      return mlir::failure();
    }
    // Get the size in bits of the input shape.
    auto size = op.getInput().getType().cast<ShapedType>().getSizeInBits();
    Operation *sameOp = nullptr;  // a potentially equivalent ChannelNormOp
    // Iterate over all users of the defining operation.
    for (auto user : def->getUsers()) {
      if (user == op) {  // skip the current operation itself
        continue;
      }
      if (isa<tx8be::ChannelNormOp>(user)) {  // another ChannelNormOp on the same input
        sameOp = user;                        // store the equivalent operation
        break;
      }
    }
    if (!sameOp) return mlir::failure();  // no redundant operation found
    // Replace all uses of the current operation with the results of the equivalent one.
    op->replaceAllUsesWith(sameOp->getOpResults());
    if (op->use_empty()) {  // erase the current operation once it has no users
      rewriter.eraseOp(op);
    }
    return success();
  }
};
struct ParamInfo {
  std::vector<uint8_t> *data_ptr;  // const value
  std::set<int32_t> chip_id;       // which chips have this const; -1 indicates all chips share the same param
  uint32_t label;                  // indicates whether the data is assigned to a certain chip_id
};

class ConstContainer {
public:
  ConstContainer();
  virtual ~ConstContainer();
  // some public functions
private:
  std::map<uint32_t, std::vector<ParamInfo>> _data;
  std::map<uint32_t, std::map<int32_t, uint64_t>> oidToSize;
  std::map<uint32_t, uint32_t> oidToNid;
};
void MoveConstantPass::runOnOperation() {
  // Create the constant container.
  createConstContainer();
  // Get the module op.
  ModuleOp module = getOperation();
  // Set up the pattern.
  MLIRContext *context = &getContext();
  RewritePatternSet patterns(context);
  patterns.insert<ConstantToLoadConst>(context);
  const FrozenRewritePatternSet frozen_patterns = FrozenRewritePatternSet(std::move(patterns));
  // Set config.
  GreedyRewriteConfig config;
  config.useTopDownTraversal = true;
  for (auto func : module.getOps<func::FuncOp>()) {
    Region &body = func.getBody();
    if (failed(applyPatternsAndFoldGreedily(body, frozen_patterns, config))) {
      llvm::errs() << "Failed when move const in main graph.\n";
      signalPassFailure();
    }
  }
  for (auto subgraph : module.getOps<tx8be::SubgraphOp>()) {
    Region &body = subgraph.getBody();
    if (failed(applyPatternsAndFoldGreedily(body, frozen_patterns, config))) {
      llvm::errs() << "Failed when move const in subgraph.\n";
      signalPassFailure();
    }
  }
  TileInfo tinfo = get_tileinfo(module);
  updateConstContainer(tinfo.tile_num);  // update id by thresholdSize
  updateLdConstop();
}
struct ConstantToLoadConst : public mlir::OpRewritePattern<tx8be::ConstantOp> {
  ConstantToLoadConst(mlir::MLIRContext *context)
      : OpRewritePattern<tx8be::ConstantOp>(context, /*benefit=*/1) {}

  mlir::LogicalResult matchAndRewrite(tx8be::ConstantOp op,
                                      mlir::PatternRewriter &rewriter) const override {
    uint32_t id = 0;
    // Store constant data to the constant container.
    // ...
    // Determine if this constant operation needs an explicit load instruction.
    bool needLoad = false;
    auto v = op.getOutput();
    // Iterate over all operations that use this output value.
    for (auto user_op : v.getUsers()) {
      // Get the argument index of the user op that corresponds to our output value.
      int32_t arg_idx = getArgumentIdx(user_op, v);
      // Assert that the user operation implements our custom OpLibInterface.
      ASSERT(llvm::dyn_cast<tx8be::OpLibInterface>(user_op));
      // Get the library attributes for this user operation.
      auto opAttr = llvm::dyn_cast<tx8be::OpLibInterface>(user_op).queryOpAttr();
      // Skip if the user is a TupleOp, which might have special handling.
      if (isa<tx8be::TupleOp>(user_op)) {
        continue;
      }
      if (opAttr.needLoad & (1 << arg_idx)) {  // check the 'needLoad' bit for this argument
        needLoad = true;
      } else {
        ASSERT(needLoad == false);
      }
    }
    // Set attributes.
    // ...
    // Safely iterate over the users; the use-list is modified inside the loop.
    for (auto &use : llvm::make_early_inc_range(op.getOutput().getUses())) {
      Operation *userOp = use.getOwner();
      // Create the new, hardware-specific LoadConst operation.
      tx8be::LoadConstOp newLoadConst = rewriter.create<tx8be::LoadConstOp>(
          op.getLoc(), op.getOutput().getType(), ValueRange{}, attrs);
      if (!needLoad) {  // this constant does not need an explicit load...
        // Get a builder to set attributes.
        OpBuilder builder(newLoadConst.getContext());
        // Set a 'bypasscodegen' attribute, signaling special handling in later stages.
        newLoadConst.getOperation()->setAttr("bypasscodegen", builder.getBoolAttr(true));
      }
      // Set the layout string attribute on the new LoadConst op.
      newLoadConst->setAttr("layout_str", op->getAttr("layout_str"));
      // CRITICAL STEP: rewire the user's operand to the result of the new LoadConst op.
      userOp->setOperand(use.getOperandNumber(), newLoadConst);
    }
    // After all uses have been replaced, erase the original, now-dead ConstantOp.
    rewriter.eraseOp(op);
    return success();
  }
};
void ConstNormPass::runOnOperation() {
  ModuleOp module = getOperation();
  func::FuncOp mainGraphFunc = getMainFuncOp(module);
  SmallVector<Operation *> deletedChannelnorm;
  // Walk the main function to find a specific pattern: LoadConst -> ChannelNorm.
  mainGraphFunc.walk([&](Operation *constOp) {
    if (isa<tx8be::LoadConstOp>(constOp)) {
      std::unordered_set<Operation *> users;
      users.insert(constOp->getUsers().begin(), constOp->getUsers().end());
      bool flag = false;
      // Check if any user is a ChannelNormOp.
      for (auto user : users) {
        if (isa<tx8be::ChannelNormOp>(user)) {
          flag = true;
          break;
        }
      }
      // If the LoadConst has exactly one user, and that user is a ChannelNormOp,
      // mark the ChannelNormOp for deletion (the erase itself happens later,
      // outside the walk).
      if (flag && users.size() == 1) {
        for (auto it : users) {
          deletedChannelnorm.push_back(it);
        }
      }
    }
  });
  // Erase all the marked ChannelNormOps. This is done in a separate loop
  // to avoid iterator invalidation issues.
  for (auto op : deletedChannelnorm) {
    op->erase();
  }
  // Set up and run a nested pass pipeline.
  OpPassManager thisPM(this->getOpName().value());
  // This pipeline will only apply to LoadConstOp operations inside functions.
  OpPassManager &loadConstOpPM = thisPM.nest<func::FuncOp>().nest<tx8be::LoadConstOp>();
  // Add the ConstNormDoPass to the pipeline.
  loadConstOpPM.addPass(std::make_unique<ConstNormDoPass>());
  // Run the newly constructed pipeline on the module.
  auto result = this->runPipeline(thisPM, getOperation());
  // After the pipeline, run a final cleanup/consistency check function.
  processMultiUse(module);
  // Change unpack input0 qweight shape after ConstNormDoPass.
  // (This shape-fix logic may instead live inside ConstNormDoPass::runOnOperation().)
  mainGraphFunc.walk([&](Operation *constOp) {
    if (isa<tx8be::LoadConstOp>(constOp)) {
      // Collect all users of this LoadConstOp.
      std::unordered_set<Operation *> users;
      users.insert(constOp->getUsers().begin(), constOp->getUsers().end());
      // Check if any user is an UnpackOp.
      bool flag = false;
      for (auto user : users) {
        if (isa<tx8be::UnpackOp>(user)) {
          flag = true;
          break;
        }
      }
      // If there is exactly one user, and it is an UnpackOp...
      if (flag && users.size() == 1) {
        for (auto it : users) {
          // Make sure we are modifying the correct operand.
          if (constOp->getResult(0) == it->getOperand(0)) {
            // Get the original shape and type.
            llvm::SmallVector<int64_t, 6> oShape;
            auto type = constOp->getResult(0).getType().cast<ShapedType>();
            auto shape = type.getShape();
            // Apply the shape transformation, e.g. for unpacking packed data.
            oShape.push_back((int32_t)shape[0] / 4);
            oShape.push_back((int32_t)shape[1] * 4);
            // Create a new tensor type with the new shape.
            auto oType = mlir::RankedTensorType::get(oShape, type.getElementType());
            // Update the type of the LoadConstOp's result in-place.
            constOp->getResult(0).setType(oType);
          }
        }
      }
    }
  });
}
// This function erases a ChannelNormOp by bypassing it and updating the source constant's layout.
void constChannelNormErase(tx8be::ChannelNormOp op) {
  // Find the defining operation of the ChannelNorm's operand, which should be a LoadConstOp.
  auto defOp = llvm::dyn_cast_or_null<tx8be::LoadConstOp>(op->getOperand(0).getDefiningOp());
  // If the source is not a LoadConstOp, do nothing.
  if (!defOp) return;
  // Collect all users of the ChannelNormOp's result.
  llvm::SmallVector<Operation *> userVec;
  userVec.insert(userVec.end(), op->getUsers().begin(), op->getUsers().end());
  for (auto user : userVec) {
    // Replace each use of the ChannelNormOp's result with the LoadConstOp's result.
    user->replaceUsesOfWith(op->getResult(0), op->getOperand(0));
  }
  // After bypassing, the layout of the source constant is adjusted to reflect
  // the transformation that the ChannelNormOp was supposed to perform.
  // Set the const layout to cx mode.
  auto dev_layout = getDevInfoLayoutMode(defOp);
  auto align_dev_layout = get_aligned_layout((LAYOUT_MODE)dev_layout);
  setDevInfoWithLayout(defOp->getContext(), defOp, static_cast<tx8be::LayoutMode>(align_dev_layout));
}
// This function processes multi-use constants to ensure their layouts are consistent.
void ConstNormPass::processMultiUse(ModuleOp module) {
  func::FuncOp mainGraphFunc = getMainFuncOp(module);
  // When a const is used by multiple users, multiple loadconsts will be generated,
  // but only one loadconst will have its layout set; the others will be skipped.
  // We need to go over them uniformly.
  // First, find all previous useless constant ops.
  // Group all LoadConstOp instances by their underlying constant data ID (const_map_id).
  std::unordered_map<int32_t, std::vector<mlir::Operation *>> allconst;
  mainGraphFunc.walk([&](Operation *constOp) {
    if (isa<tx8be::LoadConstOp>(constOp)) {
      auto cOp = llvm::dyn_cast<tx8be::LoadConstOp>(constOp);
      uint32_t t_map_id = cOp.getConstMapId();
      allconst[t_map_id].emplace_back(constOp);
    }
  });
  // Based on the duplicates, decide whether the layout needs to be changed to cx:
  // check whether any duplicate already carries a cx layout.
  // Iterate over each group of LoadConstOps that share the same data.
  for (auto &kv : allconst) {
    if (kv.second.size() > 1) {  // process only if there are multiple loadconsts
      // Start from the layout of the first loadconst.
      auto layout = (LAYOUT_MODE)getDevInfoLayoutMode(kv.second.front());
      // If any duplicate uses a cx layout, prefer that layout for the whole group.
      for (auto op : kv.second) {
        auto layout2 = (LAYOUT_MODE)getDevInfoLayoutMode(op);
        if (is_cx_layout(layout2) != ALIGN_NOT) {
          layout = layout2;
          break;
        }
      }
      // Force all LoadConstOps in this group to the same layout.
      for (auto op : kv.second) {
        auto ctx = op->getContext();
        ASSERT(op->hasAttr("dev_info") && "Must have dev_info!");
        setDevInfoWithLayout(ctx, op, (tx8be::LayoutMode)layout);
      }
    }
  }
}
void GroupPatternPass::runOnOperation() {
  TFUNC_SCOPE(DEBUG);
  auto subgraphOp = getOperation();  // the operation (subgraph) this pass runs on
  PatternManager manager;            // holds graph rewriting information
  Automation aca(&manager);          // custom 'Automation' class for the pattern matching logic
  auto minfo = getModuleConfig(getModuleByOp(getOperation()));
  std::string path = "";
  auto temp = path != "" ? getPatternsFromFile(path)  // load patterns from a file if a path is given
                         : (patternConfigMap.at(static_cast<GroupPatternMode>(minfo.opt_group)));  // otherwise load from a pre-defined map keyed by the config
  TLOG(INFO) << "[GroupPatternPass] config id: " << minfo.opt_group;
  aca.insertPatterns(temp);  // insert the loaded patterns into the Automation engine; this builds the matching structure
  TLOG(INFO) << "[Automation]: \n" << printTree(aca.root);
  aca.search(subgraphOp);    // run the search for all patterns on the given subgraph
  manager.applyAll();
  auto groups = manager.getGroups();  // retrieve the groups of operations that were matched
  manager.show();
  auto newGroups = createGroups(subgraphOp, groups);  // create new group structures from the matched results
  for (auto group : newGroups) {
    sortTopologically(group->getBlock());  // topologically sort ops in each new group to keep data dependencies
  }
}
void Automation::insertPatterns(std::map<std::vector<TX8BE_OPS>, int> patterns) {
  std::vector<std::vector<TX8BE_OPS>> tempPatterns;
  for (auto it : patterns) {               // iterate through each pattern from the input map
    auto temp = processPattern(it.first);  // pre-process the pattern; one pattern can expand into many
    for (auto p : temp) {                  // for each of the generated concrete patterns...
      insertPattern(p, it.second);         // ...insert it into the main data structure (the Trie)
    }
  }
}
struct TrieNode {
  TrieNode(TX8BE_OPS id) : id(id) {}  // initialize the node with an operation ID
  TX8BE_OPS id;                       // the op type this node represents; the 'character' in the sequence
  std::vector<int> output;            // IDs of patterns that end at this node; non-empty means a valid match
  std::vector<TX8BE_OPS> pattern;     // the complete operator sequence of the pattern that ends here
  std::unordered_map<TX8BE_OPS, NodePtr> children;  // map from op type to the next Trie node (NodePtr is a shared_ptr to TrieNode)
};

void Automation::insertPattern(const std::vector<TX8BE_OPS> pattern, int index) {
  patterns_.push_back(pattern);   // store the raw pattern vector
  auto node = root;               // start from the root of the Trie
  for (auto op : pattern) {       // iterate through each operation in the pattern sequence
    if (node->children.find(op) == node->children.end()) {   // if a path for this op does not exist...
      node->children[op] = std::make_shared<TrieNode>(op);   // ...create a new node in the Trie
    }
    node = node->children[op];    // move to the next node in the Trie
  }
  node->pattern = pattern;        // mark this node as terminal by storing the full pattern
  node->output.push_back(index);  // store the original pattern index/ID at this terminal node
}
NodePtr Automation::searchOp(NodePtr parentNode, Operation *op) {
  auto opId = getOpId(op);  // get the enumerated ID (e.g. TX8BE_OPS::CONV) for this MLIR operation
  if (isRealOp(op) && parentNode->children.find(opId) == parentNode->children.end()) {
    // The current op is a "real" operation (not a terminator, etc.) but is not among the
    // parent Trie node's children: a mismatch. This case falls through to the final return.
  }
  if (parentNode->children.find(opId) != parentNode->children.end()) {
    // A path exists in the Trie for `opId`: a potential match, continue downwards.
    auto currentNode = parentNode->children[opId];  // move to the matched Trie node
    auto tempNode = currentNode;  // the longest/best match found so far along this path
    // --- Query operation attributes and users ---
    auto queryInterface = llvm::dyn_cast<tx8be_mlir::OpLibInterface>(op);  // interface for querying op attributes
    auto needStore = queryInterface.queryOpAttr().needStore;  // e.g. whether the op's result must be stored
    llvm::SmallSet<Operation *, 1> users;  // all direct users of the current operation's results
    for (auto user : op->getUsers()) {
      users.insert(user);
    }
    auto sortedUsers = manager_->sortOps(users);  // sort the users, e.g. topologically or by priority
    // --- Recursively search through users ---
    for (auto user : sortedUsers) {
      if (!isRealOp(user)) continue;  // skip non-essential ops
      auto interface = llvm::dyn_cast<tx8be_mlir::OpLibInterface>(user);
      auto needLoad = interface.queryOpAttr().needLoad;
      if (!needStore && needLoad) continue;  // skip paths with a store/load attribute mismatch
      // Recursively call searchOp for the user operation, starting from the current Trie node.
      auto terminalNode = searchOp(currentNode, user);
      // --- Update the best match ---
      if (!terminalNode->output.empty() && !tempNode->output.empty()) {
        // Both the previous best (`tempNode`) and the new match are valid pattern ends:
        // compare priority and keep the one with the higher priority
        // (the int ID is assumed to represent priority).
        if (terminalNode->output.front() > tempNode->output.front()) {
          tempNode = terminalNode;
        }
      } else if (!terminalNode->output.empty()) {
        // `tempNode` was not a valid pattern end but `terminalNode` is: take it.
        tempNode = terminalNode;
      }
    }
    // TFOOTER(TRACE)
    return tempNode;  // the node for the longest/best pattern found from this point
  }
  // The parent node cannot match the current op: return the parent node.
  return parentNode;
}
void Automation::search(tx8be::SubgraphOp subgraph) {
  // result maps k -> v, where
  // k: the starting operation of a matched pattern
  // v: the type/ID of the matched pattern
  std::map<Operation *, int> result;
  manager_->initDefsMap(subgraph);  // initialize the manager with definition info from the subgraph
  int32_t index = 0;
  subgraph->walk([&](Operation *op) {     // first pass: gather metadata
    manager_->opOrder_.insert(op);        // record the sequential order of all operations
    manager_->opIndexMap_[op] = index++;  // assign a unique index to each operation
  });
  // Second pass: walk the subgraph again to perform the actual pattern matching.
  subgraph->walk([&](Operation *op) {
    // Skip the subgraph itself and its return op; they are not part of a computational pattern.
    if (isa<tx8be::SubgraphOp, tx8be::SubgraphReturnOp>(op)) {
      return WalkResult::skip();
    }
    auto pattern = std::make_shared<Pattern>(op);  // a potential match starting at `op`
    manager_->patterns_.push_back(pattern);        // add this potential pattern to the manager's list
    manager_->patternMap_[op] = pattern;           // map the operation `op` to its Pattern object
    // terminalNode is the last matched Trie node.
    // This is the main call to the recursive search, starting from the Trie root for each `op`.
    auto terminalNode = searchOp(root, op);
    // If the node has an output, a match was found. If multiple matches exist, they are
    // replaced by priority during the search phase, so the final result is the
    // highest-priority pattern.
    if (!terminalNode->output.empty()) {
      // Update the Pattern object with the results from the terminal node.
      pattern->setPattern(terminalNode->output.front(), terminalNode->pattern);
      // Record the result: map the starting operation `op` to the matched pattern's ID.
      result[op] = terminalNode->output.front();
    }
    return WalkResult::advance();  // proceed to the next operation in the walk
  });
}
The pass walks all OPs in a compute subGraph. For every ordinary compute operation that passes the filter, createSingleGroup is called to create a dedicated GroupOp for it.
createSingleGroup inspects all inputs of the original OP. If an input comes from another compute operation, it becomes an input of the new GroupOp. If the input is a LoadConstOp, it is treated as an internal dependency of the group rather than an external input. All outputs of the original op directly become outputs of the new GroupOp.
The new GroupOp gets the inputs and outputs defined in the previous step. The original operation and its constant dependencies are moved inside the newly created GroupOp. Finally, the connections of the original OP are rewired so that, inside the group, it receives its inputs and produces its outputs correctly. Pseudocode follows:
for op in subGraph.ops:
// check the type of the operation
if op == (GroupOp || ReturnOp || LoadConstOp || NoneOp):
continue
createSingleGroup(op)
------------------------------------
createSingleGroup(op):
for pre_op in op.inputsOp:
// is the producer a LoadConst or a None op?
if pre_op == (LoadConstOp || NoneOp):
// if so, add it to the dependencies set
dependencies.add(pre_op)
else:
// otherwise it is an ordinary compute op: its result becomes a group input
groupInput.add(pre_op.result)
for result in op.results: // iterate over all results of the current op
// add these results to the group's outputs
groupOutput.add(result)
// create a new GroupOp with the collected inputs and outputs
create GroupOp(groupInput, groupOutput)
// move the dependencies (e.g. constants) to the end of the new group
move dependencies to group end
// move the original op itself to the end of the new group
move op to group end
// rewire the original op's inputs and outputs so it connects correctly inside the group
change op input and output
void GroupLdStPass::runOnOperation() {
  subgraph.walk([&](tx8be::GroupOp g_op) {
    // ...
    // For each group input, insert a load. If it is used by multiple ops, multiple loads are inserted.
    for (auto v : g_op.getBody().front().getArguments()) {  // iterate over each input argument of the group
      Operation *pre_op = getValidDefiningOp(v);  // the operation that produces this input
      // ...
      std::map<Operation *, int32_t> usersLoad;  // users that need to load this input
      for (auto userOp : v.getUsers()) {  // all users of this input argument
        // ...
        // Check whether the user needs a 'load' for this argument, based on its attributes.
        if (!(opAttr.needLoad & (1 << arg_idx))) {
          continue;
        }
        // If a load is needed, record the user and its argument index.
        usersLoad.insert(std::make_pair(userOp, arg_idx));
        // This block handles layout logic for Add/Sub/Mul/Div ops: if one input of the
        // binary op is a rank-1 tensor, the other side may be forced to a specific
        // layout (LayoutMode::Cx).
        if (isa<tx8be::AddOp, tx8be::SubOp, tx8be::DivOp, tx8be::MulOp>(userOp)) {
          // ... [layout handling]
        }
      }
      if (usersLoad.size() != 0) {  // there are users that require a load operation
        std::vector<NamedAttribute> tmp_attrs;
        // ... [build the LoadVarOp attributes]
        // Create the Load operation.
        auto ld = builder.create<tx8be::LoadVarOp>(g_op.getLoc(), v.getType(), v, tmp_attrs);
        // ... [set dynamic shape attributes]
        // For each user that needs the load...
        for (auto userOp : usersLoad) {
          // ...replace its use of the original input `v` with the result of the new load `ld`.
          userOp.first->replaceUsesOfWith(v, ld.getOutput());
        }
      }
    }
    // For each group output, insert a store.
    builder.setInsertionPointToEnd(&block);  // insert at the end of the group's body
    Operation *g_return = g_op.getBody().front().getTerminator();  // the return op of the group
    for (int i = 0; i < g_return->getNumOperands(); ++i) {  // iterate over each output of the group
      auto value = g_return->getOperand(i);
      auto pre_op = value.getDefiningOp();  // the op inside the group that produces this output
      // ...
      // Check whether this output value needs to be stored for external users.
      if (!(llvm::dyn_cast<tx8be::OpLibInterface>(pre_op).queryOpAttr().needStore & (1 << i))) {
        continue;
      }
      // ... [build the StoreVarOp attributes]
      // Create the Store operation.
      auto st = builder.create<tx8be::StoreVarOp>(g_op.getLoc(), value.getType(), value, st_attrs);
      // ... [set dynamic shape attributes]
      // Update the group's return instruction to return the result of the store op.
      g_return->setOperand(i, st.getOutput());
    }
    g_return->moveBefore(gBlock, block.end());  // move the return instruction to the end of the block
    updateIR(g_op);  // update the IR of the group op
  });
}
// Performs a simple mapping of groups onto the tile grid.
void simpleGroupMapping(ModuleOp module) {
  // Get the x and y dimensions from the module's attributes.
  // These attributes are defined globally for the entire compilation.
  uint32_t x_dim = module->getAttrOfType<IntegerAttr>(tx8be::ModuleAttr::TileDx).getInt();
  uint32_t y_dim = module->getAttrOfType<IntegerAttr>(tx8be::ModuleAttr::TileDy).getInt();
  // Create an OpBuilder, a helper to create/modify MLIR operations.
  OpBuilder builder(module.getContext());
  // Get the 'main' function from the module.
  func::FuncOp main = module.getMainFuncOp();
  // Get the first (entry) block of the main function.
  auto &main_block = main.getBody().front();
  for (auto &inner : main_block.getOperations()) {  // iterate over all operations in the main function's body
    if (isa<tx8be::CallOp>(inner)) {
      // The main function contains CallOps, i.e. indirect calls to subgraphs.
      // Find the SubgraphOp definition using the symbol name from the CallOp.
      tx8be::SubgraphOp sg = module.lookupSymbol<tx8be::SubgraphOp>(llvm::dyn_cast<tx8be::CallOp>(inner).getCallee());
      // Walk the operations inside the called subgraph, looking for the GroupOps,
      // which are the actual units of computation.
      sg.walk([&](tx8be::GroupOp gop) {
        // Set a 'dev_region' attribute on the located GroupOp.
        setDevRegionAttr(builder, module.getContext(), gop.getOperation(), x_dim, y_dim);
      });
    }
    if (isa<tx8be::GroupOp>(inner)) {
      // The main function directly contains GroupOps:
      // set the 'dev_region' attribute on the GroupOp found in the main function.
      setDevRegionAttr(builder, module.getContext(), llvm::dyn_cast<tx8be::GroupOp>(inner).getOperation(), x_dim, y_dim);
    }
  }
}

void GroupMappingPass::runOnOperation() {
  // Operates on the entire ModuleOp.
  auto module = getOperation();
  simpleGroupMapping(module);
}
/**
 * @struct BufferLabel
 * @brief A unique identifier for a memory buffer.
 *
 * This struct links a buffer to a specific MLIR Value and tracks whether it's
 * a special "immediate" buffer. It's used as a key in maps to associate
 * MLIR Values with their buffer metadata.
 */
struct BufferLabel {
  // The MLIR Value that this buffer represents, typically a tensor produced
  // by an operation.
  mlir::Value v;
  // A flag indicating if this buffer holds a special "immediate" value.
  // Immediate values might be treated differently during allocation (e.g.,
  // small constants or internal scratchpads for an op).
  bool isImm{false};
  /**
   * @brief Equality operator to compare two labels.
   *
   * Two labels are considered equal if they refer to the same MLIR Value
   * and have the same 'isImm' status. This is necessary for using
   * BufferLabel as a key in std::map or std::unordered_map.
   */
  bool operator==(const BufferLabel &other) const {
    return (v == other.v) && (isImm == other.isImm);
  }
};
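To actually use BufferLabel as an std::unordered_map key, a hash functor is needed alongside operator==. A minimal sketch follows (the hash-combining scheme is an assumption, not necessarily the project's implementation):

#include <cstddef>
#include <functional>
#include "mlir/IR/Value.h"

struct BufferLabelHash {
  size_t operator()(const BufferLabel &l) const {
    // mlir::hash_value(Value) is provided by MLIR; combine it with the isImm flag.
    size_t h = static_cast<size_t>(mlir::hash_value(l.v));
    return h ^ (std::hash<bool>{}(l.isImm) + 0x9e3779b9u + (h << 6) + (h >> 2));
  }
};
// Usage: std::unordered_map<BufferLabel, int64_t, BufferLabelHash> bufferSizes;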
ValueBuffer holds all the metadata required for a single buffer, including its identity, lifetime, size, and final memory location.
/**
 * @struct ValueBuffer
 * @brief Represents the metadata for a single memory buffer, including its
 * lifetime, size, and allocation information.
 */
struct ValueBuffer {
  // The unique identifier for this buffer.
  BufferLabel label;
  // The starting point of the buffer's lifetime (inclusive), measured in
  // pipeline cycles. After memory allocation, this field may be repurposed
  // to store the starting memory address.
  int64_t start;
  // The ending point of the buffer's lifetime (inclusive), measured in
  // pipeline cycles. After memory allocation, this field may be repurposed
  // to store the ending memory address.
  int64_t end;
  // The total size of this buffer in bytes, as required by its tensor shape.
  int64_t allSize{0};
  // Size of an intermediate/temporary buffer that an operator might need
  // internally. This is often allocated contiguously with the main output
  // buffer; the final output address would be 'offset + immSize'.
  int64_t immSize{0};
  // The final memory offset assigned to this buffer in the scratchpad memory.
  // This value is determined by the final memory allocation pass.
  int64_t offset{0};
  /**
   * @brief Less-than operator, used for sorting ValueBuffer objects.
   *
   * The active implementation sorts buffers primarily by their lifetime start
   * time, a common strategy for greedy "first fit" style memory allocation.
   * The commented-out code shows an alternative strategy of sorting by size.
   */
  bool operator<(const ValueBuffer &other) const {
    // return this->allSize < other.allSize;  // alternative: sort by size
    return this->start <= other.start;
  }
};
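To show how these fields fit together, here is a minimal sketch of a greedy placement over lifetime-ordered buffers (illustrative only, not the project's allocator). Buffers are assumed to be already ordered by lifetime start, which is what operator< above expresses.

#include <algorithm>
#include <cstdint>
#include <vector>

int64_t assignOffsets(std::vector<ValueBuffer> &bufs) {
  int64_t peak = 0;
  std::vector<ValueBuffer> live;  // buffers still alive at the current start time
  for (auto &b : bufs) {
    // Retire buffers whose lifetime ended before this one starts.
    live.erase(std::remove_if(live.begin(), live.end(),
                              [&](const ValueBuffer &o) { return o.end < b.start; }),
               live.end());
    // Naive placement: stack the new buffer above every still-live buffer.
    int64_t off = 0;
    for (auto &o : live)
      off = std::max<int64_t>(off, o.offset + o.immSize + o.allSize);
    b.offset = off;  // the final output address would then be offset + immSize
    peak = std::max<int64_t>(peak, off + b.immSize + b.allSize);
    live.push_back(b);
  }
  return peak;  // peak scratchpad usage in bytes
}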