Python: restructure integration tests (#9331)
### Motivation and Context

Our integration tests for the embedding services and memory are not set
up in a way that makes it easy for developers to add test cases.

### Description

1. Restructure the tests so that it's easier to add new tests when a new service is created (see the sketch below).
2. Reduce duplicate code.
3. Add Ollama embedding service tests.

> Note: Besides the new Ollama tests, no other tests are added or dropped.
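To make the new shape concrete, here is a minimal sketch of the pattern, with import paths and the `ollama_setup` flag assumed from the SK Python connector and test layout rather than taken verbatim from this diff:

```python
from semantic_kernel.connectors.ai.ollama import OllamaChatCompletion
from semantic_kernel.connectors.ai.ollama.ollama_prompt_execution_settings import (
    OllamaChatPromptExecutionSettings,
)

# Computed once per module from the environment; in the real test base this
# comes from is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID"]).
ollama_setup = True


def services() -> dict:
    """Map a service id to (service instance or None, settings class).

    A None instance tells the shared test body to skip that service's cases.
    """
    return {
        "ollama": (
            OllamaChatCompletion() if ollama_setup else None,
            OllamaChatPromptExecutionSettings,
        ),
    }
```

Adding a new connector then means adding one entry to this map instead of writing a new test module.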

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Evan Mattson <[email protected]>
TaoChenOSU and moonbox3 authored Oct 23, 2024
1 parent c069239 commit c304c8b
Showing 53 changed files with 985 additions and 881 deletions.
40 changes: 20 additions & 20 deletions .github/workflows/python-integration-tests.yml
@@ -89,18 +89,16 @@ jobs:
- name: Install Ollama
if: matrix.os == 'ubuntu-latest'
run: |
if ${{ vars.OLLAMA_MODEL != '' }}; then
curl -fsSL https://ollama.com/install.sh | sh
ollama serve &
sleep 5
fi
curl -fsSL https://ollama.com/install.sh | sh
ollama serve &
sleep 5
- name: Pull model in Ollama
if: matrix.os == 'ubuntu-latest'
run: |
if ${{ vars.OLLAMA_MODEL != '' }}; then
ollama pull ${{ vars.OLLAMA_MODEL }}
ollama list
fi
ollama pull ${{ vars.OLLAMA_CHAT_MODEL_ID }}
ollama pull ${{ vars.OLLAMA_TEXT_MODEL_ID }}
ollama pull ${{ vars.OLLAMA_EMBEDDING_MODEL_ID }}
ollama list
- name: Google auth
uses: google-github-actions/auth@v2
with:
@@ -158,7 +156,9 @@ jobs:
MISTRALAI_EMBEDDING_MODEL_ID: ${{ vars.MISTRALAI_EMBEDDING_MODEL_ID }}
ANTHROPIC_API_KEY: ${{secrets.ANTHROPIC_API_KEY}}
ANTHROPIC_CHAT_MODEL_ID: ${{ vars.ANTHROPIC_CHAT_MODEL_ID }}
OLLAMA_MODEL: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_MODEL || '' }}" # phi3
OLLAMA_CHAT_MODEL_ID: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_CHAT_MODEL_ID || '' }}" # phi3
OLLAMA_TEXT_MODEL_ID: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_TEXT_MODEL_ID || '' }}" # phi3
OLLAMA_EMBEDDING_MODEL_ID: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_EMBEDDING_MODEL_ID || '' }}" # nomic-embed-text
GOOGLE_AI_GEMINI_MODEL_ID: ${{ vars.GOOGLE_AI_GEMINI_MODEL_ID }}
GOOGLE_AI_EMBEDDING_MODEL_ID: ${{ vars.GOOGLE_AI_EMBEDDING_MODEL_ID }}
GOOGLE_AI_API_KEY: ${{ secrets.GOOGLE_AI_API_KEY }}
@@ -225,18 +225,16 @@ jobs:
- name: Install Ollama
if: matrix.os == 'ubuntu-latest'
run: |
if ${{ vars.OLLAMA_MODEL != '' }}; then
curl -fsSL https://ollama.com/install.sh | sh
ollama serve &
sleep 5
fi
curl -fsSL https://ollama.com/install.sh | sh
ollama serve &
sleep 5
- name: Pull model in Ollama
if: matrix.os == 'ubuntu-latest'
run: |
if ${{ vars.OLLAMA_MODEL != '' }}; then
ollama pull ${{ vars.OLLAMA_MODEL }}
ollama list
fi
ollama pull ${{ vars.OLLAMA_CHAT_MODEL_ID }}
ollama pull ${{ vars.OLLAMA_TEXT_MODEL_ID }}
ollama pull ${{ vars.OLLAMA_EMBEDDING_MODEL_ID }}
ollama list
- name: Google auth
uses: google-github-actions/auth@v2
with:
@@ -294,7 +292,9 @@ jobs:
MISTRALAI_EMBEDDING_MODEL_ID: ${{ vars.MISTRALAI_EMBEDDING_MODEL_ID }}
ANTHROPIC_API_KEY: ${{secrets.ANTHROPIC_API_KEY}}
ANTHROPIC_CHAT_MODEL_ID: ${{ vars.ANTHROPIC_CHAT_MODEL_ID }}
OLLAMA_MODEL: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_MODEL || '' }}" # phi3
OLLAMA_CHAT_MODEL_ID: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_CHAT_MODEL_ID || '' }}" # phi3
OLLAMA_TEXT_MODEL_ID: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_TEXT_MODEL_ID || '' }}" # phi3
OLLAMA_EMBEDDING_MODEL_ID: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_EMBEDDING_MODEL_ID || '' }}" # nomic-embed-text
GOOGLE_AI_GEMINI_MODEL_ID: ${{ vars.GOOGLE_AI_GEMINI_MODEL_ID }}
GOOGLE_AI_EMBEDDING_MODEL_ID: ${{ vars.GOOGLE_AI_EMBEDDING_MODEL_ID }}
GOOGLE_AI_API_KEY: ${{ secrets.GOOGLE_AI_API_KEY }}
@@ -65,7 +65,7 @@

@experimental_class
class AnthropicChatCompletion(ChatCompletionClientBase):
"""Antropic ChatCompletion class."""
"""Anthropic ChatCompletion class."""

MODEL_PROVIDER_NAME: ClassVar[str] = "anthropic"
SUPPORTS_FUNCTION_CALLING: ClassVar[bool] = True
@@ -62,4 +62,4 @@ def prepare_settings_dict(self, **kwargs) -> dict[str, Any]:
class GoogleAIEmbeddingPromptExecutionSettings(PromptExecutionSettings):
"""Google AI Embedding Prompt Execution Settings."""

output_dimensionality: int | None = None
output_dimensionality: int | None = Field(None, le=768)
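The added `le=768` bound means Pydantic now rejects oversized values at validation time rather than letting the API call fail. A minimal sketch of the same pattern, assuming Pydantic v2 and that 768 is the model's documented maximum output dimensionality:

```python
from pydantic import BaseModel, Field, ValidationError


class EmbeddingSettings(BaseModel):
    # Same constraint as above: None stays allowed, values above 768 fail fast.
    output_dimensionality: int | None = Field(None, le=768)


EmbeddingSettings(output_dimensionality=256)  # accepted
try:
    EmbeddingSettings(output_dimensionality=1024)
except ValidationError as err:
    print(err)  # "Input should be less than or equal to 768"
```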
@@ -78,7 +78,7 @@ def __init__(
)
resolved_device = f"cuda:{device}" if device >= 0 and torch.cuda.is_available() else "cpu"
super().__init__(
service_id=service_id,
service_id=service_id or ai_model_id,
ai_model_id=ai_model_id,
task=task,
device=resolved_device,
@@ -51,7 +51,7 @@ def __init__(
resolved_device = f"cuda:{device}" if device >= 0 and torch.cuda.is_available() else "cpu"
super().__init__(
ai_model_id=ai_model_id,
service_id=service_id,
service_id=service_id or ai_model_id,
device=resolved_device,
generator=sentence_transformers.SentenceTransformer(model_name_or_path=ai_model_id, device=resolved_device),
)
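Both HuggingFace services get the same fix: when no `service_id` is passed, the model id doubles as the service id, so the service can still be registered and looked up by name. The effect, sketched:

```python
def resolve_service_id(service_id: str | None, ai_model_id: str) -> str:
    # Mirrors `service_id=service_id or ai_model_id` in the constructors above.
    return service_id or ai_model_id


assert resolve_service_id(None, "gpt2") == "gpt2"
assert resolve_service_id("my-service", "gpt2") == "my-service"
```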
@@ -17,13 +17,17 @@ class OllamaSettings(KernelBaseSettings):
settings are missing.
Required settings for prefix 'OLLAMA' are:
- model: str - Model name. (Env var OLLAMA_MODEL)
- chat_model_id: str - The chat model ID. (Env var OLLAMA_CHAT_MODEL_ID)
- text_model_id: str - The text model ID. (Env var OLLAMA_TEXT_MODEL_ID)
- embedding_model_id: str - The embedding model ID. (Env var OLLAMA_EMBEDDING_MODEL_ID)
Optional settings for prefix 'OLLAMA' are:
- host: HttpsUrl - The endpoint of the Ollama service. (Env var OLLAMA_HOST)
"""

env_prefix: ClassVar[str] = "OLLAMA_"

model: str
chat_model_id: str | None = None
text_model_id: str | None = None
embedding_model_id: str | None = None
host: str | None = None
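With the single `model` field split into per-task IDs, each Ollama service validates only the variable it needs, and unset IDs stay `None` instead of failing settings creation. A hedged usage sketch, assuming the connector's module path:

```python
import os

from semantic_kernel.connectors.ai.ollama.ollama_settings import OllamaSettings

os.environ["OLLAMA_EMBEDDING_MODEL_ID"] = "nomic-embed-text"

settings = OllamaSettings.create()
assert settings.embedding_model_id == "nomic-embed-text"
assert settings.chat_model_id is None  # unset IDs no longer block creation
```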
@@ -64,20 +64,20 @@ def __init__(
"""
try:
ollama_settings = OllamaSettings.create(
model=ai_model_id,
chat_model_id=ai_model_id,
host=host,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Ollama settings.", ex) from ex

if not ollama_settings.model:
raise ServiceInitializationError("Please provide ai_model_id or OLLAMA_MODEL env variable is required")
if not ollama_settings.chat_model_id:
raise ServiceInitializationError("Ollama chat model ID is required.")

super().__init__(
service_id=service_id or ollama_settings.model,
ai_model_id=ollama_settings.model,
service_id=service_id or ollama_settings.chat_model_id,
ai_model_id=ollama_settings.chat_model_id,
client=client or AsyncClient(host=ollama_settings.host),
)

@@ -60,20 +60,20 @@ def __init__(
"""
try:
ollama_settings = OllamaSettings.create(
model=ai_model_id,
text_model_id=ai_model_id,
host=host,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Ollama settings.", ex) from ex

if not ollama_settings.model:
raise ServiceInitializationError("Please provide ai_model_id or OLLAMA_MODEL env variable is required")
if not ollama_settings.text_model_id:
raise ServiceInitializationError("Ollama text model ID is required.")

super().__init__(
service_id=service_id or ollama_settings.model,
ai_model_id=ollama_settings.model,
service_id=service_id or ollama_settings.text_model_id,
ai_model_id=ollama_settings.text_model_id,
client=client or AsyncClient(host=ollama_settings.host),
)

@@ -57,17 +57,20 @@ def __init__(
"""
try:
ollama_settings = OllamaSettings.create(
model=ai_model_id,
embedding_model_id=ai_model_id,
host=host,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Ollama settings.", ex) from ex

if not ollama_settings.embedding_model_id:
raise ServiceInitializationError("Ollama embedding model ID is not set.")

super().__init__(
service_id=service_id or ollama_settings.model,
ai_model_id=ollama_settings.model,
service_id=service_id or ollama_settings.embedding_model_id,
ai_model_id=ollama_settings.embedding_model_id,
client=client or AsyncClient(host=ollama_settings.host),
)

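Taken together, the three Ollama services can now be configured independently, either through their own environment variables or with explicit model IDs. A hedged usage sketch, assuming these classes are exported from the connector package:

```python
from semantic_kernel.connectors.ai.ollama import (
    OllamaChatCompletion,
    OllamaTextCompletion,
    OllamaTextEmbedding,
)

# Each service resolves its own *_MODEL_ID variable, so chat and text can
# run phi3 while embeddings use nomic-embed-text; previously all three
# shared a single OLLAMA_MODEL value.
chat = OllamaChatCompletion(ai_model_id="phi3")
text = OllamaTextCompletion(ai_model_id="phi3")
embeddings = OllamaTextEmbedding(ai_model_id="nomic-embed-text")
```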
3 changes: 2 additions & 1 deletion python/semantic_kernel/contents/binary_content.py
@@ -56,7 +56,7 @@ def __init__(
"""Create a Binary Content object, either from a data_uri or data.
Args:
uri (Url | None): The reference uri of the content.
uri (Url | str | None): The reference uri of the content.
data_uri (DataUrl | None): The data uri of the content.
data (str | bytes | None): The data of the content.
data_format (str | None): The format of the data (e.g. base64).
@@ -84,6 +84,7 @@ def __init__(
_data_uri = DataUri(
data_bytes=data, data_format=data_format, mime_type=mime_type or self.default_mime_type
)

if uri is not None:
if isinstance(uri, str) and os.path.exists(uri):
uri = str(FilePath(uri))
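The docstring now matches what the constructor accepts: `uri` may be a plain string, and a string that names an existing local file is normalized to a file path. A small hedged example:

```python
from semantic_kernel.contents.binary_content import BinaryContent

# A plain string uri is accepted; non-file strings are kept as reference URIs.
content = BinaryContent(uri="https://example.com/report.pdf")
print(content.uri)
```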
31 changes: 23 additions & 8 deletions python/tests/integration/completions/chat_completion_test_base.py
@@ -11,6 +11,10 @@
from azure.identity import DefaultAzureCredential
from openai import AsyncAzureOpenAI

from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import (
AnthropicChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.anthropic.services.anthropic_chat_completion import AnthropicChatCompletion
from semantic_kernel.connectors.ai.azure_ai_inference.azure_ai_inference_prompt_execution_settings import (
AzureAIInferenceChatPromptExecutionSettings,
)
@@ -51,20 +55,30 @@
from semantic_kernel.kernel import Kernel
from semantic_kernel.kernel_pydantic import KernelBaseModel
from tests.integration.completions.completion_test_base import CompletionTestBase, ServiceType
from tests.integration.completions.test_utils import is_service_setup_for_testing
from tests.integration.test_utils import is_service_setup_for_testing

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
from typing_extensions import override # pragma: no cover

# This can later also be simplified as map probably
mistral_ai_setup: bool = is_service_setup_for_testing("MISTRALAI_API_KEY")
ollama_setup: bool = is_service_setup_for_testing("OLLAMA_MODEL")
google_ai_setup: bool = is_service_setup_for_testing("GOOGLE_AI_API_KEY")
vertex_ai_setup: bool = is_service_setup_for_testing("VERTEX_AI_PROJECT_ID")
anthropic_setup: bool = is_service_setup_for_testing("ANTHROPIC_API_KEY")
onnx_setup: bool = is_service_setup_for_testing("ONNX_GEN_AI_CHAT_MODEL_FOLDER")
# Make sure all services are set up before running the tests
# The following exceptions apply:
# 1. OpenAI and Azure OpenAI services are always setup for testing.
# 2. Bedrock services don't use API keys and model providers are tested individually,
# so no environment variables are required.
mistral_ai_setup: bool = is_service_setup_for_testing(
["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"], raise_if_not_set=False
) # We don't have a MistralAI deployment
ollama_setup: bool = is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID"])
google_ai_setup: bool = is_service_setup_for_testing(["GOOGLE_AI_API_KEY", "GOOGLE_AI_GEMINI_MODEL_ID"])
vertex_ai_setup: bool = is_service_setup_for_testing(["VERTEX_AI_PROJECT_ID", "VERTEX_AI_GEMINI_MODEL_ID"])
onnx_setup: bool = is_service_setup_for_testing(
["ONNX_GEN_AI_CHAT_MODEL_FOLDER"], raise_if_not_set=False
) # Tests are optional for ONNX
anthropic_setup: bool = is_service_setup_for_testing(
["ANTHROPIC_API_KEY", "ANTHROPIC_CHAT_MODEL_ID"], raise_if_not_set=False
) # We don't have an Anthropic deployment

skip_on_mac_available = platform.system() == "Darwin"
if not skip_on_mac_available:
@@ -122,6 +136,7 @@ def services(self) -> dict[str, tuple[ServiceType, type[PromptExecutionSettings]
"azure": (AzureChatCompletion(), AzureChatPromptExecutionSettings),
"azure_custom_client": (azure_custom_client, AzureChatPromptExecutionSettings),
"azure_ai_inference": (azure_ai_inference_client, AzureAIInferenceChatPromptExecutionSettings),
"anthropic": (AnthropicChatCompletion() if anthropic_setup else None, AnthropicChatPromptExecutionSettings),
"mistral_ai": (
MistralAIChatCompletion() if mistral_ai_setup else None,
MistralAIChatPromptExecutionSettings,
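The helper itself moved up to `tests/integration/test_utils` so it can be shared beyond the completions tests. Its body isn't shown in this diff; a hedged reconstruction consistent with the calls above:

```python
import os


def is_service_setup_for_testing(env_var_names: list[str], raise_if_not_set: bool = True) -> bool:
    """Assumed shape: True only when every listed variable is set and non-empty.

    With raise_if_not_set=True (the default), a missing variable fails fast so
    required services cannot be skipped silently in CI; optional services pass
    raise_if_not_set=False and have their test cases skipped instead.
    """
    all_set = all(os.getenv(name) for name in env_var_names)
    if not all_set and raise_if_not_set:
        raise KeyError(f"Environment variable(s) not set: {', '.join(env_var_names)}")
    return all_set
```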
@@ -24,7 +24,7 @@
vertex_ai_setup,
)
from tests.integration.completions.completion_test_base import ServiceType
from tests.integration.completions.test_utils import retry
from tests.integration.test_utils import retry

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -667,6 +667,7 @@ class FunctionChoiceTestTypes(str, Enum):
]
],
{"test_type": FunctionChoiceTestTypes.NON_AUTO},
marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."),
id="bedrock_anthropic_claude_tool_call_non_auto",
),
pytest.param(
@@ -698,6 +699,7 @@ class FunctionChoiceTestTypes(str, Enum):
],
],
{"test_type": FunctionChoiceTestTypes.FLOW},
marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."),
id="bedrock_anthropic_claude_tool_call_flow",
),
pytest.param(
@@ -716,6 +718,7 @@ class FunctionChoiceTestTypes(str, Enum):
]
],
{"test_type": FunctionChoiceTestTypes.AUTO},
marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."),
id="bedrock_anthropic_claude_tool_call_auto_complex_return_type",
),
# endregion
@@ -737,6 +740,7 @@ class FunctionChoiceTestTypes(str, Enum):
]
],
{"test_type": FunctionChoiceTestTypes.NON_AUTO},
marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."),
id="bedrock_cohere_command_tool_call_non_auto",
),
pytest.param(
@@ -768,6 +772,7 @@ class FunctionChoiceTestTypes(str, Enum):
],
],
{"test_type": FunctionChoiceTestTypes.FLOW},
marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."),
id="bedrock_cohere_command_tool_call_flow",
),
pytest.param(
@@ -786,6 +791,7 @@ class FunctionChoiceTestTypes(str, Enum):
]
],
{"test_type": FunctionChoiceTestTypes.AUTO},
marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."),
id="bedrock_cohere_command_tool_call_auto_complex_return_type",
),
# endregion
@@ -19,7 +19,7 @@
vertex_ai_setup,
)
from tests.integration.completions.completion_test_base import ServiceType
from tests.integration.completions.test_utils import retry
from tests.integration.test_utils import retry

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -221,6 +221,7 @@
ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="Where was it made?")]),
],
{},
marks=pytest.mark.skip(reason="Skipping due to occasional throttling from Bedrock."),
id="bedrock_anthropic_claude_image_input_file",
),
],