Commit 72d3697

[Feature] add paddle.trunc (#33371)

* new api trunc, test=develop

1 parent 32e3353 commit 72d3697
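
For orientation, a minimal usage sketch of the API this commit adds (dygraph mode; the input values are illustrative, not taken from the commit):

    import paddle

    # trunc rounds every element toward zero, dropping the fractional part
    x = paddle.to_tensor([[1.7, -1.7], [0.5, -0.5]])
    out = paddle.trunc(x)
    print(out)  # [[ 1., -1.], [ 0., -0.]]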

File tree

8 files changed: +396 −0 lines changed

paddle/fluid/framework/unused_var_check.cc

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() {
       "data_norm_grad",                     // 0
       "update_loss_scaling",                // 0
       "fused_embedding_eltwise_layernorm",  // 0
+      "trunc_grad",                         // 1
   });
   return *allow_set;
 }

trunc_grad goes on this allow list because its kernels never read the values of Out@GRAD (the gradient of trunc is identically zero), so the unused-variable check must not flag that input.

paddle/fluid/operators/trunc_op.cc

Lines changed: 89 additions & 0 deletions (new file)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/trunc_op.h"

namespace paddle {
namespace operators {

class TruncOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "trunc");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "trunc");
    auto input_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", input_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

class TruncOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of trunc op.");
    AddOutput("Out", "(Tensor), The output tensor of trunc op.");
    AddComment(R"DOC(
Trunc Operator.
Returns a new tensor with the truncated integer values of input.
$$out = trunc(x)$$
)DOC");
  }
};

class TruncGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   framework::GradVarName("Out"), "TruncGrad");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                   framework::GradVarName("X"), "TruncGrad");

    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
  }
};

template <typename T>
class TruncGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> retv) const override {
    retv->SetType("trunc_grad");
    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    retv->SetAttrMap(this->Attrs());
    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker,
                  ops::TruncGradOpMaker<paddle::framework::OpDesc>,
                  ops::TruncGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp);

REGISTER_OP_CPU_KERNEL(trunc, ops::TruncKernel<float>, ops::TruncKernel<double>,
                       ops::TruncKernel<int>, ops::TruncKernel<int64_t>);

REGISTER_OP_CPU_KERNEL(trunc_grad, ops::TruncGradKernel<float>,
                       ops::TruncGradKernel<double>, ops::TruncGradKernel<int>,
                       ops::TruncGradKernel<int64_t>);
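
For reference, the semantics these registrations pin down, sketched in NumPy (not part of the commit): the forward pass truncates elementwise, and the backward pass discards the incoming gradient, since trunc is flat between integers.

    import numpy as np

    def trunc_forward(x):
        # elementwise round-toward-zero, keeping the input dtype
        return np.trunc(x).astype(x.dtype)

    def trunc_backward(dout):
        # trunc is piecewise constant, so dX is identically zero
        return np.zeros_like(dout)

    x = np.array([-1.5, -0.2, 0.2, 1.5])
    print(trunc_forward(x))            # [-1. -0.  0.  1.]
    print(trunc_backward(np.ones(4)))  # [0. 0. 0. 0.]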

paddle/fluid/operators/trunc_op.cu

Lines changed: 115 additions & 0 deletions (new file)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
   (Apache-2.0 license header, as in trunc_op.cc above.) */

#include "paddle/fluid/operators/trunc_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;

template <typename T>
class TruncFunctor {
 public:
  __device__ TruncFunctor(const T x) : x_(x) {}
  __device__ T operator()() { return trunc(x_); }

 public:
  const T x_;
};

// Integer inputs are already "truncated": the specializations pass them
// through unchanged instead of calling trunc().
template <>
class TruncFunctor<int> {
 public:
  __device__ TruncFunctor(const int x) : x_(x) {}
  __device__ int operator()() { return x_; }

 public:
  const int x_;
};

template <>
class TruncFunctor<int64_t> {
 public:
  __device__ TruncFunctor(const int64_t x) : x_(x) {}
  __device__ int64_t operator()() { return x_; }

 public:
  const int64_t x_;
};

template <typename T>
__global__ void Trunc(const T* x, T* out, int64_t N) {
  CUDA_KERNEL_LOOP(index, N) {
    TruncFunctor<T> functor(x[index]);
    out[index] = functor();
  }
}

template <typename T>
__global__ void TruncGrad(T* dx, int64_t N) {
  CUDA_KERNEL_LOOP(index, N) { dx[index] = static_cast<T>(0.0); }
}

template <typename T>
class TruncCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* out = context.Output<Tensor>("Out");

    const auto* x_data = x->data<T>();
    auto* out_data = out->mutable_data<T>(context.GetPlace());

    int64_t numel = x->numel();

    int threads = PADDLE_CUDA_NUM_THREADS;
    int blocks = (numel + threads - 1) / threads;

    Trunc<<<blocks, threads>>>(x_data, out_data, numel);
  }
};

template <typename T>
class TruncCUDAGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));

    // dout's values are never read: the gradient is identically zero.
    const auto* dout_data = dout->data<T>();
    auto* dx_data = dx->mutable_data<T>(context.GetPlace());

    int64_t numel = dout->numel();

    int threads = PADDLE_CUDA_NUM_THREADS;
    int blocks = (numel + threads - 1) / threads;

    TruncGrad<<<blocks, threads>>>(dx_data, numel);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(trunc, ops::TruncCUDAKernel<float>,
                        ops::TruncCUDAKernel<double>, ops::TruncCUDAKernel<int>,
                        ops::TruncCUDAKernel<int64_t>);

REGISTER_OP_CUDA_KERNEL(trunc_grad, ops::TruncCUDAGradKernel<float>,
                        ops::TruncCUDAGradKernel<double>,
                        ops::TruncCUDAGradKernel<int>,
                        ops::TruncCUDAGradKernel<int64_t>);
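
The launch configuration rounds up (blocks = (numel + threads − 1) / threads) so every element gets a thread, and CUDA_KERNEL_LOOP is, as I understand it, Paddle's grid-stride loop macro. A Python sketch of the resulting index pattern, with threads=512 standing in for PADDLE_CUDA_NUM_THREADS (an assumption; see platform/cuda_primitives.h):

    def covered_indices(numel, threads=512):
        # threads=512 is a stand-in for PADDLE_CUDA_NUM_THREADS (assumption)
        blocks = (numel + threads - 1) // threads  # round-up division, as in the kernels
        stride = blocks * threads
        for thread_id in range(stride):            # one iteration per CUDA thread
            index = thread_id
            while index < numel:                   # grid-stride loop (CUDA_KERNEL_LOOP)
                yield index
                index += stride

    # with a round-up launch, each element index is visited exactly once
    assert sorted(covered_indices(1000)) == list(range(1000))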

paddle/fluid/operators/trunc_op.h

Lines changed: 55 additions & 0 deletions (new file)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
   (Apache-2.0 license header, as in trunc_op.cc above.) */

#pragma once

#include <math.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class TruncKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* x = context.Input<Tensor>("X");
    Tensor* out = context.Output<Tensor>("Out");

    size_t numel = x->numel();
    const T* x_data = x->data<T>();
    T* out_data = out->mutable_data<T>(context.GetPlace());

    for (size_t i = 0; i < numel; i++) {
      out_data[i] = trunc(x_data[i]);
    }
  }
};

template <typename T>
class TruncGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
    T* dx_data = dx->mutable_data<T>(context.GetPlace());

    // trunc is piecewise constant, so its gradient is zero everywhere it is
    // defined; zero-fill dX without reading Out@GRAD at all.
    int64_t numel = dx->numel();
    memset(dx_data, 0, numel * sizeof(T));
  }
};

}  // namespace operators
}  // namespace paddle
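
Zero-filling with memset is valid for every type this op registers (float, double, int, int64_t), because the all-zero byte pattern encodes 0 for two's-complement integers and +0.0 for IEEE-754 floats — a quick NumPy check:

    import numpy as np

    zero_bytes = bytes(8)  # eight 0x00 bytes
    print(np.frombuffer(zero_bytes, dtype=np.float64))  # [0.]
    print(np.frombuffer(zero_bytes, dtype=np.int64))    # [0]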

python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -205,6 +205,7 @@
 from .tensor.math import prod  # noqa: F401
 from .tensor.math import broadcast_shape  # noqa: F401
 from .tensor.math import conj  # noqa: F401
+from .tensor.math import trunc  # noqa: F401
 from .tensor.math import digamma  # noqa: F401
 from .tensor.math import neg  # noqa: F401
 from .tensor.math import lgamma  # noqa: F401

@@ -490,6 +491,7 @@
     'log10',
     'concat',
     'check_shape',
+    'trunc',
     'digamma',
     'standard_normal'
 ]
Test file (path absent from the diff header; by Paddle's test layout this is presumably python/paddle/fluid/tests/unittests/test_trunc_op.py)

Lines changed: 88 additions & 0 deletions (new file)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# (Apache-2.0 license header, as above.)

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

paddle.enable_static()


class TestTruncOp(OpTest):
    def setUp(self):
        self.op_type = "trunc"
        self.dtype = np.float64
        self.init_dtype_type()  # let subclasses override the dtype
        np.random.seed(2021)
        self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)}
        self.outputs = {'Out': (np.trunc(self.inputs['X']))}

    def init_dtype_type(self):
        self.dtype = np.float64

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5)


class TestFloatTruncOp(TestTruncOp):
    def init_dtype_type(self):
        self.dtype = np.float32


class TestIntTruncOp(TestTruncOp):
    def init_dtype_type(self):
        self.dtype = np.int32


class TestTruncAPI(unittest.TestCase):
    def setUp(self):
        self.shape = [20, 20]
        self.x = np.random.random((20, 20)).astype(np.float32)
        self.place = paddle.CPUPlace()

    def test_api_static(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.fluid.data('X', self.shape)
            out = paddle.trunc(x)
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'X': self.x}, fetch_list=[out])
        out_ref = np.trunc(self.x)
        for out in res:
            self.assertEqual(np.allclose(out, out_ref, rtol=1e-08), True)

    def test_api_dygraph(self):
        paddle.disable_static(self.place)
        x_tensor = paddle.to_tensor(self.x)
        out = paddle.trunc(x_tensor)
        out_ref = np.trunc(self.x)
        self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True)
        paddle.enable_static()

    def test_errors(self):
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.fluid.data('X', [20, 20], 'bool')
            self.assertRaises(TypeError, paddle.trunc, x)


if __name__ == "__main__":
    unittest.main()
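
The tests use np.trunc as the reference; note that truncation rounds toward zero, which differs from floor on negative inputs:

    import numpy as np

    x = np.array([1.7, -1.7, 0.5, -0.5])
    print(np.trunc(x))  # [ 1. -1.  0. -0.]  -> toward zero
    print(np.floor(x))  # [ 1. -2.  0. -1.]  -> toward -infinity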

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -162,6 +162,7 @@
 from .math import any  # noqa: F401
 from .math import broadcast_shape  # noqa: F401
 from .math import conj  # noqa: F401
+from .math import trunc  # noqa: F401
 from .math import digamma  # noqa: F401
 from .math import neg  # noqa: F401
 from .math import lgamma  # noqa: F401

@@ -349,5 +350,6 @@
     'shape',
     'real',
     'imag',
+    'trunc',
     'digamma'
 ]
