Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions debug_toolbar/panels/sql/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.db import connections
from django.utils.functional import cached_property

from debug_toolbar.panels.sql.utils import reformat_sql
from debug_toolbar.panels.sql.utils import is_select_query, reformat_sql


class SQLSelectForm(forms.Form):
Expand All @@ -27,7 +27,7 @@ class SQLSelectForm(forms.Form):
def clean_raw_sql(self):
value = self.cleaned_data["raw_sql"]

if not value.lower().strip().startswith("select"):
if not is_select_query(value):
raise ValidationError("Only 'select' queries are allowed.")

return value
Expand Down
10 changes: 6 additions & 4 deletions debug_toolbar/panels/sql/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from debug_toolbar.panels.sql import views
from debug_toolbar.panels.sql.forms import SQLSelectForm
from debug_toolbar.panels.sql.tracking import wrap_cursor
from debug_toolbar.panels.sql.utils import contrasting_color_generator, reformat_sql
from debug_toolbar.panels.sql.utils import (
contrasting_color_generator,
is_select_query,
reformat_sql,
)
from debug_toolbar.utils import render_stacktrace


Expand Down Expand Up @@ -266,9 +270,7 @@ def generate_stats(self, request, response):
query["sql"] = reformat_sql(query["sql"], with_toggle=True)

query["is_slow"] = query["duration"] > sql_warning_threshold
query["is_select"] = (
query["raw_sql"].lower().lstrip().startswith("select")
)
query["is_select"] = is_select_query(query["raw_sql"])

query["rgb_color"] = self._databases[alias]["rgb_color"]
try:
Expand Down
5 changes: 5 additions & 0 deletions debug_toolbar/panels/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def process(stmt):
return "".join(escaped_value(token) for token in stmt.flatten())


def is_select_query(sql):
# UNION queries can start with "(".
return sql.lower().lstrip(" (").startswith("select")


def reformat_sql(sql, *, with_toggle=False):
formatted = parse_sql(sql)
if not with_toggle:
Expand Down
1 change: 1 addition & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Change log

Pending
-------
* Support select and explain buttons for ``UNION`` queries on PostgreSQL.

4.4.6 (2024-07-10)
------------------
Expand Down
10 changes: 10 additions & 0 deletions tests/panels/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,16 @@ def test_similar_and_duplicate_grouping(self):
self.assertNotEqual(queries[0]["similar_color"], queries[3]["similar_color"])
self.assertNotEqual(queries[0]["duplicate_color"], queries[3]["similar_color"])

@unittest.skipUnless(
connection.vendor == "postgresql", "Test valid only on PostgreSQL"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come this is only valid for postgres? I feel like the other db drivers should also pass this test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a quick look at the Django code, this can only happen if the DB backend has supports_slicing_ordering_in_compound = True. But I see now it is more widely supported than I thought. Shall I rather use

@unittest.skipUnless(connection.vendor != "sqlite", "...")

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why the skipUnless is even necessary? The test should pass on all database engines anyway, or not?

)
def test_explain(self):
list(User.objects.filter(id__lt=20).union(User.objects.filter(id__gt=10)))
response = self.panel.process_request(self.request)
self.panel.generate_stats(self.request, response)
query = self.panel._queries[0]
self.assertTrue(query["is_select"])


class SQLPanelMultiDBTestCase(BaseMultiDBTestCase):
panel_id = "SQLPanel"
Expand Down
20 changes: 20 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,26 @@ def test_sql_explain_checks_show_toolbar(self):
)
self.assertEqual(response.status_code, 404)

@unittest.skipUnless(
connection.vendor == "postgresql", "Test valid only on PostgreSQL"
)
def test_sql_explain_union_query(self):
url = "/__debug__/sql_explain/"
data = {
"signed": SignedDataForm.sign(
{
"sql": "(SELECT * FROM auth_user) UNION (SELECT * from auth_user)",
"raw_sql": "(SELECT * FROM auth_user) UNION (SELECT * from auth_user)",
"params": "{}",
"alias": "default",
"duration": "0",
}
)
}

response = self.client.post(url, data)
self.assertEqual(response.status_code, 200)

@unittest.skipUnless(
connection.vendor == "postgresql", "Test valid only on PostgreSQL"
)
Expand Down