Skip to content

Commit 68867bf

Browse files
Prevent unitialized variable use in grappler.
PiperOrigin-RevId: 399702928 Change-Id: Id7e75451fbff297692dfb687f60ea04b25c96b24
1 parent fad123a commit 68867bf

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

tensorflow/core/grappler/optimizers/auto_parallel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
152152
TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, item.fetch, &train_nodes));
153153
LOG(INFO) << "Number of training nodes: " << train_nodes.size();
154154

155-
const NodeDef* dequeue_node;
155+
const NodeDef* dequeue_node = nullptr;
156156
for (const auto& train_node : train_nodes) {
157157
if (IsDequeueOp(*train_node)) {
158158
dequeue_node = train_node;

tensorflow/core/grappler/optimizers/auto_parallel_test.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,30 @@ TEST_F(AutoParallelTest, SimpleParallel) {
126126
EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
127127
}
128128

129+
TEST_F(AutoParallelTest, SimpleParallelNoDequeue) {
130+
tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope();
131+
Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
132+
Output constant_c = ops::Const(s.WithOpName("constant_c"), 1.0f, {1});
133+
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
134+
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
135+
Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
136+
Output add = ops::AddN(s.WithOpName("add"), {constant_a, constant_c});
137+
Output learning_rate = ops::Const(s.WithOpName("learning_rate"), 0.01f, {1});
138+
Output apply_gradient = ops::ApplyGradientDescent(
139+
s.WithOpName("apply_gradient"), {var}, {learning_rate}, {add});
140+
141+
GrapplerItem item;
142+
item.init_ops.push_back("assign");
143+
item.fetch.push_back("apply_gradient");
144+
item.init_ops.push_back("assign");
145+
TF_CHECK_OK(s.ToGraphDef(&item.graph));
146+
147+
AutoParallel parallel(2);
148+
GraphDef output;
149+
Status status = parallel.Optimize(nullptr, item, &output);
150+
TF_EXPECT_OK(status);
151+
}
152+
129153
} // namespace
130154
} // namespace grappler
131155
} // namespace tensorflow

0 commit comments

Comments
 (0)