fix: bug: ModelBuilder overwrites user-provided HF_MODEL_ID for DJL Serving, preventi (5529)#5734
Conversation
…erving, preventi (5529)
sagemaker-bot
left a comment
There was a problem hiding this comment.
🤖 AI Code Review
The fix correctly replaces .update() with .setdefault() to preserve user-provided HF_MODEL_ID values, which is a clean and minimal change. The tests cover all affected methods with both preservation and default-setting scenarios. However, there are several issues with the test file: it uses unittest style instead of pytest conventions, has lines exceeding 100 characters, and has trailing whitespace in the source file.
| # Configure HuggingFace model support | ||
| if not self._is_jumpstart_model_id(): | ||
| self.env_vars.update({"HF_MODEL_ID": self.model}) | ||
| self.env_vars.setdefault("HF_MODEL_ID", self.model) |
There was a problem hiding this comment.
Nit: There appears to be trailing whitespace on this line (after setdefault). Same issue on lines 215, 323, 535. While CI formatting tools may catch this, it's worth cleaning up.
| @@ -0,0 +1,275 @@ | |||
| """Unit tests to verify HF_MODEL_ID is not overwritten when user provides it.""" | |||
There was a problem hiding this comment.
The SDK uses pytest as the test framework (per unit test standards). This file uses unittest.TestCase with self.assertEqual. Please refactor to use pytest conventions:
- Use plain test functions or classes without inheriting
TestCase - Use
assertstatements instead ofself.assertEqual - Use
@pytest.fixtureinstead of helper functions for shared setup - Remove
if __name__ == '__main__': unittest.main()
Example:
def test_djl_preserves_user_provided_s3_uri(...):
...
assert builder.env_vars["HF_MODEL_ID"] == s3_path|
|
||
| @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") | ||
| @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") | ||
| @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) |
There was a problem hiding this comment.
This line exceeds 100 characters (the SDK's line length limit). Several other decorator lines in this file also exceed the limit (lines 69, 97, 98, etc.). Please wrap long lines to stay within 100 characters.
@patch(
"sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree",
return_value=1,
)| @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") | ||
| @patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None) | ||
| @patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1) | ||
| @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) |
There was a problem hiding this comment.
Long function signature exceeds 100 characters. Please wrap parameters across multiple lines.
| from sagemaker.serve.mode.function_pointers import Mode | ||
|
|
||
|
|
||
| def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"): |
There was a problem hiding this comment.
Consider using @pytest.fixture for the mock builder creation instead of a plain helper function. This would be more idiomatic pytest and allow parameterization:
@pytest.fixture
def mock_builder():
"""Create a mock builder with common attributes set."""
...| from sagemaker.serve.mode.function_pointers import Mode | ||
|
|
||
|
|
||
| def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"): |
There was a problem hiding this comment.
Missing type annotations on the helper function. Per SDK conventions, public/utility functions should have type hints:
def _create_mock_builder(
env_vars: dict[str, str] | None = None,
model: str = "Qwen/Qwen3-VL-4B-Instruct",
) -> MagicMock:|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
There was a problem hiding this comment.
Remove if __name__ == '__main__': unittest.main() — the SDK runs tests via pytest, not unittest's runner.
| @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1) | ||
| def test_preserves_user_provided_s3_uri(self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config): | ||
| """User-provided S3 URI for HF_MODEL_ID should not be overwritten.""" | ||
| mock_hf_config.return_value = {} |
There was a problem hiding this comment.
Consider using @pytest.mark.parametrize to reduce duplication across the test classes. Many tests follow the same pattern (preserve vs. set default) across different server types. For example:
@pytest.mark.parametrize("build_method,server_type,patches", [
("_build_for_djl", ModelServer.DJL_SERVING, [...]),
("_build_for_tgi", ModelServer.TGI, [...]),
...
])
def test_preserves_user_provided_hf_model_id(build_method, server_type, patches):
...This would significantly reduce the test file size while maintaining coverage.
🤖 Iteration #1 — Review Comments AddressedDescriptionFix bug where ProblemMultiple FixThe source file already correctly uses Test ChangesRewrote the test file to follow SDK conventions:
TestingAll tests verify that:
This is backward compatible — existing behavior where Comments reviewed: 8
|
Description
The bug is in sagemaker-serve/src/sagemaker/serve/model_builder_servers.py. Multiple build_for* methods unconditionally overwrite HF_MODEL_ID using self.env_vars.update({'HF_MODEL_ID': ...}), which destroys any user-provided value (e.g., an S3 URI for DJL serving). The fix is to replace each .update({'HF_MODEL_ID': ...}) call with .setdefault('HF_MODEL_ID', ...) so that user-provided values are preserved. This affects 6 methods: _build_for_torchserve, _build_for_tgi, _build_for_djl, _build_for_triton, _build_for_tei, and _build_for_transformers.
Related Issue
Related issue: 5529
Changes Made
sagemaker-serve/src/sagemaker/serve/model_builder_servers.pysagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.pyAI-Generated PR
This PR was automatically generated by the PySDK Issue Agent.
Merge Checklist
prefix: descriptionformat