Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self) -> None:
class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
_azure_endpoint: httpx.URL | None
_azure_deployment: str | None
_is_v1_api: bool

@override
def _build_request(
Expand All @@ -60,10 +61,12 @@ def _build_request(
*,
retries_taken: int = 0,
) -> httpx.Request:
if options.url in _deployments_endpoints and is_mapping(options.json_data):
model = options.json_data.get("model")
if model is not None and "/deployments" not in str(self.base_url.path):
options.url = f"/deployments/{model}{options.url}"
# v1 API doesn't use /deployments/{model}/ path - model is passed in body
if not getattr(self, '_is_v1_api', False):
if options.url in _deployments_endpoints and is_mapping(options.json_data):
model = options.json_data.get("model")
if model is not None and "/deployments" not in str(self.base_url.path):
options.url = f"/deployments/{model}{options.url}"

return super()._build_request(options, retries_taken=retries_taken)

Expand Down Expand Up @@ -208,6 +211,9 @@ def __init__(
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
)

# Check if using v1 API format (new Azure OpenAI API)
_is_v1_api = api_version in ("v1", "latest", "preview")

if default_query is None:
default_query = {"api-version": api_version}
else:
Expand All @@ -222,7 +228,10 @@ def __init__(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)

if azure_deployment is not None:
if _is_v1_api:
# v1 API uses /openai/v1/ path without /deployments/
base_url = f"{azure_endpoint.rstrip('/')}/openai/v1"
elif azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
Expand Down Expand Up @@ -253,6 +262,7 @@ def __init__(
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
self._is_v1_api = _is_v1_api

@override
def copy(
Expand Down Expand Up @@ -489,6 +499,9 @@ def __init__(
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
)

# Check if using v1 API format (new Azure OpenAI API)
_is_v1_api = api_version in ("v1", "latest", "preview")

if default_query is None:
default_query = {"api-version": api_version}
else:
Expand All @@ -503,7 +516,10 @@ def __init__(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)

if azure_deployment is not None:
if _is_v1_api:
# v1 API uses /openai/v1/ path without /deployments/
base_url = f"{azure_endpoint.rstrip('/')}/openai/v1"
elif azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
Expand Down Expand Up @@ -534,6 +550,7 @@ def __init__(
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
self._is_v1_api = _is_v1_api

@override
def copy(
Expand Down
85 changes: 85 additions & 0 deletions tests/lib/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,3 +802,88 @@ def test_client_sets_base_url(client: Client) -> None:
)
)
assert req.url == "https://example-resource.azure.openai.com/openai/models?api-version=2024-02-01"


# Tests for v1 API support
class TestAzureV1API:
"""Tests for Azure OpenAI v1/latest/preview API support."""

@pytest.mark.parametrize("api_version", ["v1", "latest", "preview"])
@pytest.mark.parametrize("client_cls", [AzureOpenAI, AsyncAzureOpenAI])
def test_v1_api_base_url(self, api_version: str, client_cls: type[Client]) -> None:
"""v1 API should use /openai/v1/ base URL."""
client = client_cls(
api_version=api_version,
api_key="test",
azure_endpoint="https://example.azure.openai.com",
)
assert "/openai/v1" in str(client.base_url)
assert "/deployments/" not in str(client.base_url)

@pytest.mark.parametrize("api_version", ["v1", "latest", "preview"])
@pytest.mark.parametrize("client_cls", [AzureOpenAI, AsyncAzureOpenAI])
def test_v1_api_no_deployments_path(self, api_version: str, client_cls: type[Client]) -> None:
"""v1 API should NOT add /deployments/{model}/ to the path."""
client = client_cls(
api_version=api_version,
api_key="test",
azure_endpoint="https://example.azure.openai.com",
)
req = client._build_request(
FinalRequestOptions.construct(
method="post",
url="/chat/completions",
json_data={"model": "gpt-4o"},
)
)
assert "/deployments/" not in str(req.url)
assert "/openai/v1/chat/completions" in str(req.url)

@pytest.mark.parametrize("api_version", ["v1", "latest", "preview"])
@pytest.mark.parametrize("client_cls", [AzureOpenAI, AsyncAzureOpenAI])
def test_v1_api_has_query_param(self, api_version: str, client_cls: type[Client]) -> None:
"""v1 API should still include ?api-version= query param."""
client = client_cls(
api_version=api_version,
api_key="test",
azure_endpoint="https://example.azure.openai.com",
)
req = client._build_request(
FinalRequestOptions.construct(
method="post",
url="/chat/completions",
json_data={"model": "gpt-4o"},
)
)
assert f"api-version={api_version}" in str(req.url)

@pytest.mark.parametrize("client_cls", [AzureOpenAI, AsyncAzureOpenAI])
def test_traditional_api_still_works(self, client_cls: type[Client]) -> None:
"""Traditional API should still use /deployments/ path."""
client = client_cls(
api_version="2024-10-21",
api_key="test",
azure_endpoint="https://example.azure.openai.com",
)
req = client._build_request(
FinalRequestOptions.construct(
method="post",
url="/chat/completions",
json_data={"model": "gpt-4o"},
)
)
assert "/deployments/gpt-4o/" in str(req.url)
assert "api-version=2024-10-21" in str(req.url)

@pytest.mark.parametrize("api_version", ["v1", "latest", "preview"])
def test_v1_api_ignores_azure_deployment_param(self, api_version: str) -> None:
"""v1 API should ignore azure_deployment parameter since model is in body."""
client = AzureOpenAI(
api_version=api_version,
api_key="test",
azure_endpoint="https://example.azure.openai.com",
azure_deployment="ignored-deployment",
)
# base_url should still be /openai/v1, not /openai/deployments/ignored-deployment
assert "/openai/v1" in str(client.base_url)
assert "/deployments/" not in str(client.base_url)