[AutoParallel] Polish dist tensor design #56368
Your PR was submitted successfully. Thank you for contributing to this open-source project!
const Place& DistTensor::place() const {
inline void check_defined(const DistTensor& dist_tensor,
                          std::string method_hint) {
Could this simply take a const& here? And what does marking the function inline accomplish in a .cc file?
Using string is meant to avoid one extra copy construction, and inline should still take effect at compile time.
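The parameter-passing question in this thread can be illustrated in general C++ terms. The sketch below is not the PR's code: `FakeDistTensor` and both `check_defined_*` helpers are hypothetical stand-ins. For a helper that only reads its hint, `const std::string&` binds to an existing string without copying, while a by-value `std::string` parameter is constructed from the argument (a copy for an lvalue, a move for an rvalue). By value tends to pay off only when the callee keeps the string by moving from it. For a function defined in a single .cc file, `inline` is mostly an optimization hint to the compiler.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for DistTensor; only what the sketch needs.
struct FakeDistTensor {
  bool defined = false;
};

// By const reference: an lvalue std::string argument is not copied.
inline void check_defined_ref(const FakeDistTensor& tensor,
                              const std::string& method_hint) {
  if (!tensor.defined) {
    throw std::runtime_error("Tensor is not defined when calling " +
                             method_hint);
  }
}

// By value: the parameter is a fresh std::string built from the argument.
// Behavior is identical here; only the parameter construction differs.
inline void check_defined_val(const FakeDistTensor& tensor,
                              std::string method_hint) {
  if (!tensor.defined) {
    throw std::runtime_error("Tensor is not defined when calling " +
                             method_hint);
  }
}
```

Both helpers behave the same from the caller's side, so the choice is about cost and intent rather than correctness.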
const TensorDistAttr& dist_attr);

/// \brief Construct a dist tensor based dense tensor.
/// \param value The global dense tensor of the current tensor.
Shouldn't this constructor be passed the local_tensor?
done, thx
@@ -219,9 +209,6 @@ void InitDistTensorWithNumpyValue(TensorObject* self,
"Place should be one of "
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/CustomPlace"));
}

// TODO(dev): dist_tensor meta is not equal to dense tensor meta
dist_tensor_ptr->set_meta(impl_ptr->meta());
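There seems to be a problem here. InitDistTensorWithNumpyValue receives numpy data in the global view, but it takes the DenseTensor pointer out of the DistTensor and overwrites it directly with that numpy data. This bypasses reshard, so the correct local DenseTensor is never set.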
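Done. This is a key point: the design needs to be adjusted here so that this loophole is not left open to developers. It is changed so that developers cannot obtain mutable access to the DenseTensor value inside a DistTensor without passing in correct DistAttr information, which keeps the distributed attributes consistent with the data.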
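The restriction described in this reply can be sketched as a class that only hands out read-only views and accepts new data solely together with its distributed attribute. This is a minimal illustration under assumptions, not Paddle's implementation: `GuardedDistTensor`, `GuardedDenseTensor`, `GuardedDistAttr`, and `set_value` are hypothetical names invented for the example.

```cpp
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical, heavily simplified stand-ins for the real types.
struct GuardedDistAttr {
  std::vector<int64_t> dims_mapping;
};
struct GuardedDenseTensor {
  std::vector<float> data;
};

class GuardedDistTensor {
 public:
  GuardedDistTensor(GuardedDenseTensor value, GuardedDistAttr attr)
      : value_(std::move(value)), dist_attr_(std::move(attr)) {}

  // Only const access is exposed, so callers cannot edit the data in place.
  const GuardedDenseTensor& value() const { return value_; }
  const GuardedDistAttr& dist_attr() const { return dist_attr_; }

  // Hypothetical setter: data can change only together with its attribute,
  // so the pair is always replaced as a unit.
  void set_value(GuardedDenseTensor value, GuardedDistAttr attr) {
    value_ = std::move(value);
    dist_attr_ = std::move(attr);
  }

 private:
  GuardedDenseTensor value_;
  GuardedDistAttr dist_attr_;
};

// Compile-time check that value() really returns a reference to const.
static_assert(
    std::is_const<typename std::remove_reference<
        decltype(std::declval<GuardedDistTensor&>().value())>::type>::value,
    "value() must only hand out read-only access");
```

Replacing value and attribute as a unit is one way to keep them from drifting apart; whether the attribute actually describes the data still has to be validated by the caller or by the setter.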
dense_temp->meta(),
std::make_shared<
phi::distributed::auto_parallel::TensorDistAttr>());
*dense_temp, phi::distributed::auto_parallel::TensorDistAttr());
Is the DenseTensor obtained here global or local?
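The initial gradient passed back is currently global, but the TensorDistAttr passed in still needs further refinement. The sharding can be done in the first backward API.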
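The global versus local distinction in this thread comes down to shape arithmetic, which the sketch below illustrates. It rests on assumptions that are not taken from the PR: `LocalShape` is a hypothetical helper, `dims_mapping[i]` is assumed to name the mesh axis that tensor dimension `i` is split over (with -1 meaning replicated), and every sharded dimension is assumed to divide evenly.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: derive the per-rank (local) shape from a global shape.
inline std::vector<int64_t> LocalShape(
    const std::vector<int64_t>& global_shape,
    const std::vector<int64_t>& dims_mapping,
    const std::vector<int64_t>& mesh_shape) {
  if (global_shape.size() != dims_mapping.size()) {
    throw std::invalid_argument("dims_mapping must have one entry per dim");
  }
  std::vector<int64_t> local = global_shape;
  for (std::size_t i = 0; i < local.size(); ++i) {
    const int64_t mesh_axis = dims_mapping[i];
    if (mesh_axis < 0) continue;  // replicated: local size equals global size
    if (static_cast<std::size_t>(mesh_axis) >= mesh_shape.size()) {
      throw std::invalid_argument("dims_mapping refers to a missing mesh axis");
    }
    const int64_t parts = mesh_shape[static_cast<std::size_t>(mesh_axis)];
    if (parts <= 0 || local[i] % parts != 0) {
      throw std::invalid_argument("dimension is not evenly divisible");
    }
    local[i] /= parts;  // each rank along this mesh axis holds one slice
  }
  return local;
}
```

Under these assumptions, a global tensor of shape 8 x 6 whose first dimension is split over a mesh axis of size 2 has a local shape of 4 x 6, which is why global data cannot simply be stored as the local value without a slicing step.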
LGTM
LGTM
LGTM
LGTM for const_cast
* polish dist tensor design
* adjust constructor
* polish details
* polish details design
* fix compile error
* refactor init tensor impl
* fix reshard test
* polish details
* add unittest for coverage
PR types
Function optimization
PR changes
Others
Description
Pcard-73145
[AutoParallel] Polish dist tensor design
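Simplify the design of DistTensor, the basic data structure of dynamic-graph semi-automatic parallelism, so that other architectural features are easier to implement later.
Members:
DDim dims_: previously a DenseTensorMeta was used, but based on current requirements DistTensor only needs to hold a global shape in addition to what DenseTensor holds, so DDim is now used directly.
TensorDistAttr dist_attr_: previously wrapped in a shared_ptr, but based on current requirements dist_attr_ does not need to be shared with other DistTensors. Each DistTensor generally holds its own, so it is now a plain object.
DenseTensor value_: previously held through a unique_ptr. A member held through unique_ptr is normally created and destroyed entirely inside the class, but DistTensor needs to be instantiated from an external DenseTensor, and DenseTensor contains a shared_ptr allocation member, so it cannot be fully unique. It is therefore also changed to a plain object, which is easier to code against and reduces heap operations.
Constructors: simplified to the following: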
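After this change, DistTensor restricts direct access to a mutable DenseTensor value, to avoid bugs caused by dist_attr and the value becoming inconsistent. However, there are currently two places that need a mutable DenseTensor value. Their usage is safe, so const_cast is used there.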
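The member layout and the const_cast escape hatch described above can be sketched together. This is a simplified illustration under assumptions, not the PR's code: all `Sketch*` types, the constructor shown, and `UnsafeMutableValue` are hypothetical, and the PR's actual constructor list is deliberately not reproduced.

```cpp
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical simplified stand-ins; not Paddle's real types.
struct SketchDDim {
  std::vector<int64_t> dims;
};
struct SketchDistAttr {
  std::vector<int64_t> dims_mapping;
};
struct SketchDenseTensor {
  // The buffer is reference-counted, so copying a dense tensor shares it.
  std::shared_ptr<std::vector<float>> allocation;
};

class SketchDistTensor {
 public:
  // Illustrative constructor only, chosen for this sketch.
  SketchDistTensor(SketchDDim global_dims, SketchDistAttr dist_attr,
                   SketchDenseTensor local_value)
      : dims_(std::move(global_dims)),
        dist_attr_(std::move(dist_attr)),
        value_(std::move(local_value)) {}

  const SketchDDim& dims() const { return dims_; }
  const SketchDistAttr& dist_attr() const { return dist_attr_; }
  const SketchDenseTensor& value() const { return value_; }

 private:
  SketchDDim dims_;           // global shape, held by value
  SketchDistAttr dist_attr_;  // held by value, not shared with other tensors
  SketchDenseTensor value_;   // held by value; only its buffer is shared
};

// Escape hatch for trusted internal call sites. Casting away const is
// well-defined here only because the caller passes a non-const object.
inline SketchDenseTensor* UnsafeMutableValue(SketchDistTensor* tensor) {
  return const_cast<SketchDenseTensor*>(&tensor->value());
}
```

Wrapping an external dense tensor copies the small handle and shares the underlying buffer, which is the reason a unique_ptr member could not guarantee exclusive ownership. Keeping the const_cast inside one named helper makes the few mutable call sites easy to audit.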