-
Notifications
You must be signed in to change notification settings - Fork 5.9k
new api trunc, test=develop #33371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
new api trunc, test=develop #33371
Changes from 4 commits
35aebd5
cd716f2
33cc89b
b577296
9a6272b
de107a3
58b19cb
b7b5803
f76b32e
8158a58
feaeaa4
b7f18f9
2783560
df79e49
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/fluid/operators/trunc_op.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| class TruncOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| void InferShape(framework::InferShapeContext *ctx) const override { | ||
| OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "trunc"); | ||
| OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "trunc"); | ||
| auto input_dims = ctx->GetInputDim("X"); | ||
| ctx->SetOutputDim("Out", input_dims); | ||
| ctx->ShareLoD("X", /*->*/ "Out"); | ||
| } | ||
| }; | ||
|
|
||
| class TruncOpMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| void Make() override { | ||
| AddInput("X", "(Tensor), The input tensor of trunc op."); | ||
| AddOutput("Out", "(Tensor), The output tensor of trunc op."); | ||
| AddComment(R"DOC( | ||
| Trunc Operator. | ||
| Returns a new tensor with the truncated integer values of input. | ||
| $$out = trunc(x)$$ | ||
| )DOC"); | ||
| } | ||
| }; | ||
|
|
||
| class TruncGradOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| void InferShape(framework::InferShapeContext *ctx) const override { | ||
| OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", | ||
| framework::GradVarName("Out"), "TruncGrad"); | ||
| OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", | ||
| framework::GradVarName("X"), "TruncGrad"); | ||
|
|
||
| auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); | ||
| ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| class TruncGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
| public: | ||
| using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
|
||
| void Apply(GradOpPtr<T> retv) const override { | ||
| retv->SetType("trunc_grad"); | ||
| retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
| retv->SetAttrMap(this->Attrs()); | ||
| retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
| REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker, | ||
| ops::TruncGradOpMaker<paddle::framework::OpDesc>, | ||
| ops::TruncGradOpMaker<paddle::imperative::OpBase>); | ||
|
|
||
| REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp); | ||
|
|
||
| REGISTER_OP_CPU_KERNEL(trunc, ops::TruncKernel<float>, ops::TruncKernel<double>, | ||
| ops::TruncKernel<int>, ops::TruncKernel<int64_t>); | ||
|
|
||
| REGISTER_OP_CPU_KERNEL(trunc_grad, ops::TruncGradKernel<float>, | ||
| ops::TruncGradKernel<double>, ops::TruncGradKernel<int>, | ||
| ops::TruncGradKernel<int64_t>); | ||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,101 @@ | ||||||||||
| /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||||||||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||
| you may not use this file except in compliance with the License. | ||||||||||
| You may obtain a copy of the License at | ||||||||||
| http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||
| Unless required by applicable law or agreed to in writing, software | ||||||||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||
| See the License for the specific language governing permissions and | ||||||||||
| limitations under the License. */ | ||||||||||
|
|
||||||||||
| #include "paddle/fluid/operators/trunc_op.h" | ||||||||||
|
|
||||||||||
| namespace paddle { | ||||||||||
| namespace operators { | ||||||||||
|
|
||||||||||
| template <typename T> | ||||||||||
| class truncFunctor { | ||||||||||
|
||||||||||
| class truncFunctor { | |
| class TruncFunctor { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks!
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x -> x, plz refer to google code stype: https://google.github.io/styleguide/cppguide.html.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| CUDA_KERNEL_LOOP(index, N) { dx[index] = 0.0; } | |
| CUDA_KERNEL_LOOP(index, N) { | |
| dx[index] = static_cast<T>(0.0); | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
theads = platform::PADDLE_CUDA_NUM_THREADS;
blocks = (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks!
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are blank lines in copyright.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, thanks!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it changed?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now it is changed, done, thanks! |
||
|
|
||
| #pragma once | ||
|
|
||
| #include <math.h> | ||
| #include "paddle/fluid/framework/op_registry.h" | ||
| #include "paddle/fluid/framework/operator.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using Tensor = framework::Tensor; | ||
|
|
||
| template <typename T> | ||
| class TruncKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& context) const override { | ||
| const Tensor* x = context.Input<Tensor>("X"); | ||
| Tensor* out = context.Output<Tensor>("Out"); | ||
|
|
||
| size_t numel = x->numel(); | ||
| const T* x_data = x->data<T>(); | ||
| T* out_data = out->mutable_data<T>(context.GetPlace()); | ||
|
|
||
| for (size_t i = 0; i < numel; i++) { | ||
| out_data[i] = trunc(x_data[i]); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| class TruncGradKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& context) const override { | ||
| auto* dx = context.Output<Tensor>(framework::GradVarName("X")); | ||
| T* dx_data = dx->mutable_data<T>(context.GetPlace()); | ||
|
|
||
| int numel = dx->numel(); | ||
| memset(dx_data, 0.0, numel * sizeof(T)); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import print_function | ||
|
|
||
| import unittest | ||
| import numpy as np | ||
| from op_test import OpTest | ||
| import paddle | ||
| import paddle.fluid.core as core | ||
| import paddle.fluid as fluid | ||
| from paddle.fluid import Program, program_guard | ||
|
|
||
| paddle.enable_static() | ||
|
|
||
|
|
||
| class TestTruncOp(OpTest): | ||
| def setUp(self): | ||
| self.op_type = "trunc" | ||
| self.dtype = np.float64 | ||
| np.random.seed(2021) | ||
| self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)} | ||
| self.outputs = {'Out': (np.trunc(self.inputs['X']))} | ||
|
|
||
| def init_dtype_type(self): | ||
| self.dtype = np.float64 | ||
|
|
||
| def test_check_output(self): | ||
| self.check_output() | ||
|
|
||
| def test_check_grad(self): | ||
| self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5) | ||
|
|
||
|
|
||
| class TestFloatTruncOp(TestTruncOp): | ||
| def init_dtype_type(self): | ||
| self.dtype = np.float32 | ||
|
|
||
|
|
||
| class TestIntTruncOp(TestTruncOp): | ||
| def init_dtype_type(self): | ||
| self.dtype = np.int32 | ||
|
|
||
|
|
||
| class TestTruncAPI(unittest.TestCase): | ||
| def setUp(self): | ||
| self.shape = [20, 20] | ||
| self.x = np.random.random((20, 20)).astype(np.float32) | ||
| self.place = paddle.CPUPlace() | ||
|
|
||
| def test_api_static(self): | ||
| paddle.enable_static() | ||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| x = paddle.fluid.data('X', self.shape) | ||
| out = paddle.trunc(x) | ||
| exe = paddle.static.Executor(self.place) | ||
| res = exe.run(feed={'X': self.x}, fetch_list=[out]) | ||
| out_ref = np.trunc(self.x) | ||
| for out in res: | ||
| self.assertEqual(np.allclose(out, out_ref, rtol=1e-08), True) | ||
|
|
||
| def test_api_dygraph(self): | ||
| paddle.disable_static(self.place) | ||
| x_tensor = paddle.to_tensor(self.x) | ||
| out = paddle.trunc(x_tensor) | ||
| out_ref = np.trunc(self.x) | ||
| self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True) | ||
| paddle.enable_static() | ||
|
|
||
| def test_errors(self): | ||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| x = paddle.fluid.data('X', [20, 20], 'bool') | ||
| self.assertRaises(TypeError, paddle.trunc, x) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can register NoNeedBufferVars for X@GRAD to save memory, see https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/op_notes_cn.html#id6 for details.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks!