[CINN] Reconstruct shape_analysis #63790
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open source project!
Force-pushed from 657e9e5 to d6ddd3c
    std::unordered_map<symbol::DimExpr, symbol::DimExpr>;
DimExprSubstitutionPattern substitution_pattern_;

private:
  ModuleOp m_;
`ModuleOp m_` is no longer used anywhere; it can be removed.
Done
}

void InferSymbolicShapeContext::SetStaticShapeForValue(Value val) {
  auto type_info = val.type().dyn_cast<pir::DenseTensorType>();
Add some mandatory checks on the Value here so that problems are raised as early as possible.
Done
Force-pushed from 1f636f8 to 6c9225d
has_prev_op = true;
if (operand.impl() && !context_.HasShapeOrDataForValue(operand)) {
  if (!operand.defining_op()) {
    SetStaticShapeForValue(operand);
You could first check whether the operand has a static shape; if the `if` nesting gets deep, this can be rewritten in an if-continue style.
Done. The check is now performed directly inside the SetStaticShapeForValue function.
// The implementation is based on shape constraint ir.
// The implementation is based on shape constraint ir.
Duplicated comment. As it stands this comment adds no value; it can be rewritten in the next PR.
Done
    pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis);
    pir::Operation* op, pir::InferSymbolicShapeContext* infer_context);

bool IsStaticShape(const Value& value);
Move this function declaration into the .cc file; putting it in the header widens its scope of impact.
Support for VectorType has been added; it is kept as a public function so other places can use it.
LGTM
* reconstruct shape_analysis
* fix input value shape infer
* fix merge bugs
* fix concat and gather op InferSymbolicShape
* fix merge bug
* fix value_to_shape_or_data hash error and add some checks
* fix set shape for null value
* fix group op lazy infer
* add IsStaticShape check
* fix merge bug
* support static dim check and set for VectorType
* change auto to detail type
PR Category
CINN
PR Types
Not User Facing
Description
Pcard-67164
This PR reconstructs the shape_analysis class to address the interface-invocation risks introduced by the lazy-update mechanism from PR #63367. Specifically, during global graph shape inference, the parameter passed to the InferSymbolicShape functions is changed to the InferSymbolicShapeContext member of the shape_analysis class. This ensures that all inference scenarios call only the non-lazy get function, while the lazy-mode get function is activated only when used inside passes. The reconstruction also surfaced additional symbolic shape inference issues, and this PR completes the following fixes:
It is worth noting that this PR is still an intermediate state of the new symbolic-inference update mechanism. Future work includes: