CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
【Hackathon 6th No.11】为 Paddle 新增 bernoulli_ API - part #64252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
python/paddle/tensor/random.py
Outdated
zeros_mask = x < p | ||
x.masked_fill_(ones_mask, 1.0) | ||
out = x.masked_fill_(zeros_mask, 0.0) | ||
return out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x.masked_fill_(ones_mask, 1.0)
x.masked_fill_(zeros_mask, 0.0)
return x
这样吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
def set_np_compare_func(self): | ||
self.np_compare = np.array_equal | ||
|
||
def inplace_api_processing(self, var): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
也测一下类方法调用吧,x.bernoulli_
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已增加
return var.bernoulli_(self.p) | ||
|
||
def non_inplace_api_processing(self, var): | ||
return paddle.bernoulli(paddle.zeros(self.shape) + self.p) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个改成 paddle.full(self.shape, p) 更好
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
def non_inplace_api_processing(self, var): | ||
return paddle.bernoulli(paddle.zeros(self.shape) + self.p) | ||
|
||
def test_inplace_api(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测一下反向,随机数反向应该全0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加反向测试
non_inplace_var.numpy().var(), inplace_var.numpy().var(), atol=0.01 | ||
) | ||
|
||
def test_inplace_class_method(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测一下反向,随机数反向应该全0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加反向测试
@NKNaN 提交了吗 |
已提交 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
||
Args: | ||
x(Tensor): The input tensor to be filled with random values. | ||
p (float|Tensor, optional): The success probability parameter of the output Tensor's bernoulli distribution. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we support scenarios where p
can broadcast to x
? if so, we should add description p
must be broadcast to x
here, and add unit tests to check error if cannot broadcast. If not, it needs to be clarified that p can only be float and Tensor which must have same shape with x
, and add unit tests to check error if p.shape
!= x.shape
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could support the scenarios. Let me update the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
""" | ||
This is the inplace version of api ``bernoulli``, which returns a Tensor filled | ||
with random values sampled from a bernoulli distribution. The output Tensor will | ||
be inplaced with input ``x``. Please refer to :ref:`api_tensor_bernoulli`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
be inplaced with input ``x``. Please refer to :ref:`api_tensor_bernoulli`. | |
be inplaced with input ``x``. Please refer to :ref:`api_paddle_bernoulli`. |
>>> import paddle | ||
>>> x = paddle.randn([3, 4]) | ||
>>> x.bernoulli_() | ||
>>> # doctest: +SKIP('random check') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
输出能固定吗?通过固定随机数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文档问题重新提一个 PR 改吧,这个先 approve
…64252) * add bernoulli_ * update * update docs * fix typo * update code and test * add test case for backward * add test broadcast error and update docs
PR Category
User Experience
PR Types
New features
Description
新增 bernoulli_ API