diff --git a/source/cbor.c b/source/cbor.c index 602ccb21d..8cc9f76cb 100644 --- a/source/cbor.c +++ b/source/cbor.c @@ -183,30 +183,68 @@ static PyObject *s_cbor_encoder_write_pyobject(struct aws_cbor_encoder *encoder, /** * TODO: timestamp <-> datetime?? Decimal fraction <-> decimal?? */ - if (PyLong_CheckExact(py_object)) { - return s_cbor_encoder_write_pylong(encoder, py_object); - } else if (PyFloat_CheckExact(py_object)) { - return s_cbor_encoder_write_pyobject_as_float(encoder, py_object); - } else if (PyBool_Check(py_object)) { - return s_cbor_encoder_write_pyobject_as_bool(encoder, py_object); - } else if (PyBytes_CheckExact(py_object)) { - return s_cbor_encoder_write_pyobject_as_bytes(encoder, py_object); - } else if (PyUnicode_Check(py_object)) { + + /* Handle None first as it's a singleton, not a type */ + if (py_object == Py_None) { + aws_cbor_encoder_write_null(encoder); + Py_RETURN_NONE; + } + + /* Get type once for efficiency - PyObject_Type returns a new reference */ + /* https://docs.python.org/3/c-api/structures.html#c.Py_TYPE is not a stable API until 3.14, so that we cannot use + * it. */ + PyObject *type = PyObject_Type(py_object); + if (!type) { + return NULL; + } + + PyObject *result = NULL; + + /* Exact type matches first (no subclasses) - fast path */ + if (type == (PyObject *)&PyLong_Type) { + result = s_cbor_encoder_write_pylong(encoder, py_object); + } else if (type == (PyObject *)&PyFloat_Type) { + result = s_cbor_encoder_write_pyobject_as_float(encoder, py_object); + } else if (type == (PyObject *)&PyBool_Type) { + result = s_cbor_encoder_write_pyobject_as_bool(encoder, py_object); + } else if (type == (PyObject *)&PyBytes_Type) { + result = s_cbor_encoder_write_pyobject_as_bytes(encoder, py_object); + } else if (PyType_IsSubtype((PyTypeObject *)type, &PyUnicode_Type)) { /* Allow subclasses of `str` */ - return s_cbor_encoder_write_pyobject_as_text(encoder, py_object); - } else if (PyList_Check(py_object)) { + result = s_cbor_encoder_write_pyobject_as_text(encoder, py_object); + } else if (PyType_IsSubtype((PyTypeObject *)type, &PyList_Type)) { /* Write py_list, allow subclasses of `list` */ - return s_cbor_encoder_write_pylist(encoder, py_object); - } else if (PyDict_Check(py_object)) { + result = s_cbor_encoder_write_pylist(encoder, py_object); + } else if (PyType_IsSubtype((PyTypeObject *)type, &PyDict_Type)) { /* Write py_dict, allow subclasses of `dict` */ - return s_cbor_encoder_write_pydict(encoder, py_object); - } else if (py_object == Py_None) { - aws_cbor_encoder_write_null(encoder); + result = s_cbor_encoder_write_pydict(encoder, py_object); } else { - PyErr_Format(PyExc_ValueError, "Not supported type %R", (PyObject *)Py_TYPE(py_object)); + /* Check for datetime using stable ABI (slower, so checked last) */ + bool is_datetime = false; + if (aws_py_is_datetime_instance(py_object, &is_datetime) != AWS_OP_SUCCESS) { + /* Error occurred during datetime check */ + result = NULL; + } else if (is_datetime) { + /* Convert datetime to CBOR epoch time (tag 1) */ + /* Call timestamp() method - PyObject_CallMethod is more idiomatic and compatible with Python 3.8+ */ + PyObject *timestamp = PyObject_CallMethod(py_object, "timestamp", NULL); + if (timestamp) { + /* Write CBOR tag 1 (epoch time) + timestamp */ + aws_cbor_encoder_write_tag(encoder, AWS_CBOR_TAG_EPOCH_TIME); + result = s_cbor_encoder_write_pyobject_as_float(encoder, timestamp); + Py_DECREF(timestamp); + } else { + result = NULL; /* timestamp() call failed */ + } + } else { + /* Unsupported type */ + PyErr_Format(PyExc_ValueError, "Not supported type %R", type); + } } - Py_RETURN_NONE; + /* Release the type reference */ + Py_DECREF(type); + return result; } /*********************************** BINDINGS ***********************************************/ diff --git a/source/module.c b/source/module.c index 6da50b3d8..356be886f 100644 --- a/source/module.c +++ b/source/module.c @@ -36,6 +36,51 @@ static struct aws_logger s_logger; static bool s_logger_init = false; +/******************************************************************************* + * DateTime Type Cache (for stable ABI compatibility) + ******************************************************************************/ +static PyObject *s_datetime_class = NULL; + +static int s_init_datetime_cache(void) { + if (s_datetime_class) { + return AWS_OP_SUCCESS; /* Already initialized */ + } + + /* Import datetime module */ + PyObject *datetime_module = PyImport_ImportModule("datetime"); + if (!datetime_module) { + /* Python exception already set by PyImport_ImportModule */ + return aws_py_raise_error(); + } + + /* Get datetime class - new reference we'll keep */ + s_datetime_class = PyObject_GetAttrString(datetime_module, "datetime"); + Py_DECREF(datetime_module); + + if (!s_datetime_class) { + /* Python exception already set by PyObject_GetAttrString */ + return aws_py_raise_error(); + } + + return AWS_OP_SUCCESS; +} + +int aws_py_is_datetime_instance(PyObject *obj, bool *out_is_datetime) { + AWS_ASSERT(out_is_datetime); + + if (!s_datetime_class && s_init_datetime_cache() != AWS_OP_SUCCESS) { + return AWS_OP_ERR; + } + + int result = PyObject_IsInstance(obj, s_datetime_class); + if (result < 0) { + return aws_py_raise_error(); /* PyObject_IsInstance failed */ + } + + *out_is_datetime = (result != 0); + return AWS_OP_SUCCESS; +} + PyObject *aws_py_init_logging(PyObject *self, PyObject *args) { (void)self; @@ -1035,6 +1080,12 @@ PyMODINIT_FUNC PyInit__awscrt(void) { aws_register_error_info(&s_error_list); s_error_map_init(); + /* Initialize datetime type cache for stable ABI datetime support */ + if (s_init_datetime_cache() < 0) { + /* Non-fatal: datetime encoding will fail but rest of module works */ + PyErr_Clear(); + } + return m; } diff --git a/source/module.h b/source/module.h index 2f9dd217e..52d9165d7 100644 --- a/source/module.h +++ b/source/module.h @@ -70,6 +70,15 @@ struct aws_byte_cursor aws_byte_cursor_from_pyunicode(PyObject *str); * If conversion cannot occur, cursor->ptr will be NULL and a python exception is set */ struct aws_byte_cursor aws_byte_cursor_from_pybytes(PyObject *py_bytes); +/** + * Check if a PyObject is an instance of datetime.datetime using stable ABI. + * + * @param obj PyObject to check + * @param out_is_datetime Pointer to store result (true if datetime, false otherwise) + * @return AWS_OP_SUCCESS on success, AWS_OP_ERR on error (Python exception set) + */ +int aws_py_is_datetime_instance(PyObject *obj, bool *out_is_datetime); + /* Set current thread's error indicator based on aws_last_error() */ void PyErr_SetAwsLastError(void); diff --git a/test/test_cbor.py b/test/test_cbor.py index 3ab65b569..df055ea25 100644 --- a/test/test_cbor.py +++ b/test/test_cbor.py @@ -151,6 +151,86 @@ def on_epoch_time(epoch_secs): exception = e self.assertIsNotNone(exception) + def test_cbor_encode_decode_datetime(self): + """Test automatic datetime encoding/decoding""" + # Create a datetime object + dt = datetime.datetime(2024, 1, 1, 12, 0, 0) + + # Encode datetime - should automatically convert to CBOR tag 1 + timestamp + encoder = AwsCborEncoder() + encoder.write_data_item(dt) + + # Decode with callback to convert back to datetime + def on_epoch_time(epoch_secs): + return datetime.datetime.fromtimestamp(epoch_secs) + + decoder = AwsCborDecoder(encoder.get_encoded_data(), on_epoch_time) + result = decoder.pop_next_data_item() + + # Verify the result matches original datetime + self.assertEqual(dt, result) + self.assertIsInstance(result, datetime.datetime) + + # Test datetime with microsecond precision (milliseconds) + dt_with_microseconds = datetime.datetime(2024, 1, 1, 12, 0, 0, 123456) # 123.456 milliseconds + encoder3 = AwsCborEncoder() + encoder3.write_data_item(dt_with_microseconds) + + decoder3 = AwsCborDecoder(encoder3.get_encoded_data(), on_epoch_time) + result_microseconds = decoder3.pop_next_data_item() + + # Verify microsecond precision is preserved + self.assertEqual(dt_with_microseconds, result_microseconds) + self.assertEqual(dt_with_microseconds.microsecond, result_microseconds.microsecond) + self.assertIsInstance(result_microseconds, datetime.datetime) + + # Test datetime in a list + encoder2 = AwsCborEncoder() + test_list = [dt, "text", 123, dt_with_microseconds] + encoder2.write_data_item(test_list) + + decoder2 = AwsCborDecoder(encoder2.get_encoded_data(), on_epoch_time) + result_list = decoder2.pop_next_data_item() + + self.assertEqual(len(result_list), 4) + self.assertEqual(result_list[0], dt) + self.assertEqual(result_list[1], "text") + self.assertEqual(result_list[2], 123) + self.assertEqual(result_list[3], dt_with_microseconds) + # Verify microsecond precision in list + self.assertEqual(result_list[3].microsecond, 123456) + + def test_cbor_encode_unsupported_type(self): + """Test that encoding unsupported types raises ValueError""" + # Create a custom class that's not supported by CBOR encoder + class CustomClass: + def __init__(self, value): + self.value = value + + # Try to encode an unsupported type + encoder = AwsCborEncoder() + unsupported_obj = CustomClass(42) + + # Should raise ValueError with message about unsupported type + with self.assertRaises(ValueError) as context: + encoder.write_data_item(unsupported_obj) + # Verify the error message mentions "Not supported type" + self.assertIn("Not supported type", str(context.exception)) + + # Test unsupported type in a list (should also fail) + encoder2 = AwsCborEncoder() + with self.assertRaises(ValueError) as context2: + encoder2.write_data_item([1, 2, unsupported_obj, 3]) + + self.assertIn("Not supported type", str(context2.exception)) + + # Test unsupported type as dict key (should also fail) + encoder3 = AwsCborEncoder() + with self.assertRaises(ValueError) as context3: + encoder3.write_data_item({unsupported_obj: "value"}) + + self.assertIn("Not supported type", str(context3.exception)) + def _ieee754_bits_to_float(self, bits): return struct.unpack('>f', struct.pack('>I', bits))[0]