Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add utf8 safeguards, fix chr method, add unicode and utf16 parsing for String #3239

Draft
wants to merge 30 commits into
base: nightly
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7e4f0df
add better safeguards and fix chr method
martinvuyk Jul 13, 2024
7134f8f
update changelog
martinvuyk Jul 13, 2024
ab84608
rename to from_unicode
martinvuyk Jul 13, 2024
53d7038
move from_unicode to be static method
martinvuyk Jul 13, 2024
6d480b7
fix from_unicode
martinvuyk Jul 13, 2024
5236388
fix docstring
martinvuyk Jul 13, 2024
439aa21
fix indentation
martinvuyk Jul 13, 2024
c6f2dfb
fix list constructor
martinvuyk Jul 13, 2024
20bf017
fix use less lines
martinvuyk Jul 13, 2024
9a62b42
add utf16 decode
martinvuyk Jul 13, 2024
0bbc386
fix changelog
martinvuyk Jul 13, 2024
74e698b
fix detail
martinvuyk Jul 13, 2024
bf4093d
fix detail
martinvuyk Jul 13, 2024
5a2af26
fix detail
martinvuyk Jul 13, 2024
30c027f
fix detail
martinvuyk Jul 13, 2024
ddcbf0d
fix detail
martinvuyk Jul 13, 2024
9f5ee3b
simplify utf16 internals
martinvuyk Jul 13, 2024
fcc789c
fix detail
martinvuyk Jul 13, 2024
e08bc57
fix detail
martinvuyk Jul 13, 2024
9ffd5e6
fix detail
martinvuyk Jul 14, 2024
afb537a
fix detail
martinvuyk Jul 14, 2024
805041e
fix detail
martinvuyk Jul 14, 2024
0fcdf50
fix detail
martinvuyk Jul 14, 2024
be5a203
fix detail
martinvuyk Jul 14, 2024
fccdbcd
fix detail
martinvuyk Jul 14, 2024
f46ce80
add suggestion from @mzaks
martinvuyk Jul 14, 2024
6b47694
fix use unsafe_get
martinvuyk Jul 16, 2024
ca38ca3
Merge remote-tracking branch 'upstream/nightly' into add-utf8-safeguards
martinvuyk Jul 16, 2024
af3be58
use variant for unicode parsing
martinvuyk Jul 16, 2024
a4eedb0
Merge remote-tracking branch 'upstream/nightly' into add-utf8-safeguards
martinvuyk Jul 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ and deprecated their private `_byte_length()` methods. Added a warning to
future and `StringSlice.__len__` now does return the Unicode codepoints length.
([PR #2960](https://github.com/modularml/mojo/pull/2960) by [@martinvuyk](https://github.com/martinvuyk))

- Added `String.from_unicode(values: Variant[List[Int], List[UInt32]]) -> String`
and `String.from_utf16(values: List[UInt16]) -> String` static methods that
return a String containing the concatenated characters. If a Unicode codepoint
is invalid, the parsed String has a replacement character (�) at that index.
The `fn chr(c: Int) -> String` function now returns a replacement character (�)
if the Unicode codepoint is invalid.
([PR #3239](https://github.com/modularml/mojo/pull/3239) by [@martinvuyk](https://github.com/martinvuyk))

- Added new `StaticString` type alias. This can be used in place of
`StringLiteral` for runtime string arguments.

Expand Down
234 changes: 176 additions & 58 deletions stdlib/src/builtin/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ from memory import UnsafePointer, memcmp, memcpy

from utils import Span, StaticIntTuple, StringRef, StringSlice
from utils._format import Formattable, Formatter, ToFormatter
from utils.string_slice import _utf8_byte_type, _StringSliceIter
from utils.string_slice import _utf8_byte_type, _StringSliceIter, _is_valid_utf8

# ===----------------------------------------------------------------------=== #
# ord
Expand Down Expand Up @@ -97,68 +97,75 @@ fn ord(s: StringSlice) -> Int:
# ===----------------------------------------------------------------------=== #


fn chr(c: Int) -> String:
"""Returns a string based on the given Unicode code point.
fn _unicode_codepoint_utf8_byte_length(c: Int) -> Int:
    """Count how many UTF-8 length thresholds the code point `c` exceeds.

    The thresholds are `0`, `0x7F`, `0x7FF` and `0xFFFF`; the number of
    thresholds strictly below `c` is the number of bytes of its UTF-8
    encoding (1 up to `0x7F`, 2 up to `0x7FF`, 3 up to `0xFFFF`, 4 above).

    Args:
        c: The Unicode code point.

    Returns:
        The number of thresholds that are strictly smaller than `c`, i.e. a
        value in the range `[0, 4]`.

    Notes:
        No range validation happens here: every `c <= 0` (including the NUL
        code point) yields 0 and every value above `0xFFFF` yields 4, even
        when it is not a valid code point.
        NOTE(review): the comparison runs on int32 lanes; confirm how a `c`
        outside the int32 range is converted before relying on the result.
    """
    # One lane per threshold; each `True` lane contributes 1 to the sum.
    alias sizes = SIMD[DType.int32, 4](0, 0b0111_1111, 0b0111_1111_1111, 0xFFFF)
    return int((sizes < c).cast[DType.uint8]().reduce_add())

Returns the string representing a character whose code point is the integer
`c`. For example, `chr(97)` returns the string `"a"`. This is the inverse of
the `ord()` function.

Args:
c: An integer that represents a code point.
fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int):
"""Shift unicode to utf8 representation.

Returns:
A string containing a single character based on the given code point.
Unicode (represented as UInt32 BE) to UTF-8 conversion :
- 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa
- a
- 2: 00000000 00000000 00000aaa aabbbbbb -> 110aaaaa 10bbbbbb
- (a >> 6) | 0b11000000, b | 0b10000000
- 3: 00000000 00000000 aaaabbbb bbcccccc -> 1110aaaa 10bbbbbb 10cccccc
- (a >> 12) | 0b11100000, (b >> 6) | 0b10000000, c | 0b10000000
- 4: 00000000 000aaabb bbbbcccc ccdddddd -> 11110aaa 10bbbbbb 10cccccc
10dddddd
- (a >> 18) | 0b11110000, (b >> 12) | 0b10000000, (c >> 6) | 0b10000000,
d | 0b10000000
"""
# Unicode (represented as UInt32 BE) to UTF-8 conversion :
# 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa a
# 2: 00000000 00000000 00000aaa aabbbbbb -> 110aaaaa 10bbbbbb a >> 6 | 0b11000000, b | 0b10000000
# 3: 00000000 00000000 aaaabbbb bbcccccc -> 1110aaaa 10bbbbbb 10cccccc a >> 12 | 0b11100000, b >> 6 | 0b10000000, c | 0b10000000
# 4: 00000000 000aaabb bbbbcccc ccdddddd -> 11110aaa 10bbbbbb 10cccccc 10dddddd a >> 18 | 0b11110000, b >> 12 | 0b10000000, c >> 6 | 0b10000000, d | 0b10000000

if (c >> 7) == 0: # This is 1 byte ASCII char
return _chr_ascii(c)

@always_inline
fn _utf8_len(val: Int) -> Int:
debug_assert(
0 <= val <= 0x10FFFF, "Value is not a valid Unicode code point"
)
alias sizes = SIMD[DType.int32, 4](
0, 0b1111_111, 0b1111_1111_111, 0b1111_1111_1111_1111
)
var values = SIMD[DType.int32, 4](val)
var mask = values > sizes
return int(mask.cast[DType.uint8]().reduce_add())
if num_bytes == 1:
ptr[0] = UInt8(c)
return

var num_bytes = _utf8_len(c)
var p = UnsafePointer[UInt8].alloc(num_bytes + 1)
var shift = 6 * (num_bytes - 1)
var mask = UInt8(0xFF) >> (num_bytes + 1)
var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes)
Scalar.store(p, ((c >> shift) & mask) | num_bytes_marker)
ptr[0] = ((c >> shift) & mask) | num_bytes_marker
for i in range(1, num_bytes):
shift -= 6
Scalar.store(p, i, ((c >> shift) & 0b00111111) | 0b10000000)
Scalar.store(p, num_bytes, 0)
return String(p.bitcast[UInt8](), num_bytes + 1)


# ===----------------------------------------------------------------------=== #
# ascii
# ===----------------------------------------------------------------------=== #
ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000


fn _chr_ascii(c: UInt8) -> String:
"""Returns a string based on the given ASCII code point.
fn chr(c: Int) -> String:
"""Returns a String based on the given Unicode code point. This is the
inverse of the `ord()` function.

Args:
c: An integer that represents a code point.

Returns:
A string containing a single character based on the given code point.
A String containing a single character based on the given code point. If
the Unicode codepoint is invalid, a replacement char (�) is returned.

Examples:
```mojo
print(chr(97)) # "a"
print(chr(0x10FFFF + 1)) # "�"
```
.
"""
return String(String._buffer_type(c, 0))
if c < 0b1000_0000: # 1 byte ASCII char
return String(String._buffer_type(c, 0))

var num_bytes = _unicode_codepoint_utf8_byte_length(c)
var p = UnsafePointer[UInt8].alloc(num_bytes + 1)
_shift_unicode_to_utf8(p, c, num_bytes)
if not _is_valid_utf8(p, num_bytes):
debug_assert(False, "Invalid Unicode code point")
p.free()
return chr(0xFFFD)
p[num_bytes] = 0
return String(ptr=p, len=num_bytes + 1)


# ===----------------------------------------------------------------------=== #
# ascii
# ===----------------------------------------------------------------------=== #


fn _repr_ascii(c: UInt8) -> String:
Expand All @@ -178,7 +185,7 @@ fn _repr_ascii(c: UInt8) -> String:
if c == ord_back_slash:
return r"\\"
elif isprintable(c):
return _chr_ascii(c)
return String(String._buffer_type(c, 0))
elif c == ord_tab:
return r"\t"
elif c == ord_new_line:
Expand Down Expand Up @@ -746,21 +753,23 @@ struct String(

@always_inline
fn __init__(inout self, owned impl: List[UInt8]):
"""Construct a string from a buffer of bytes.
"""Construct a string from a buffer of bytes. The buffer must be
terminated with a null byte.

The buffer must be terminated with a null byte:
Args:
impl: The buffer.

Examples:
```mojo
var buf = List[UInt8]()
buf.append(ord('H'))
buf.append(ord('i'))
buf.append(0)
var hi = String(buf)
print(String(List[UInt8](72, 105, 0))) # Hi
```

Args:
impl: The buffer.
.
"""
# TODO(#933): use when llvm intrinsics can be used at compile time
# debug_assert(
# _is_valid_utf8(impl.unsafe_ptr(), len(impl)),
# "String doesn't have valid UTF-8 encoding",
# )
debug_assert(
impl[-1] == 0,
"expected last element of String buffer to be null terminator",
Expand Down Expand Up @@ -838,9 +847,7 @@ struct String(
# we don't know the capacity of ptr, but we'll assume it's the same or
# larger than len
self = Self(
Self._buffer_type(
unsafe_pointer=ptr.bitcast[UInt8](), size=len, capacity=len
)
Self._buffer_type(unsafe_pointer=ptr, size=len, capacity=len)
)

fn __init__(inout self, obj: PythonObject):
Expand Down Expand Up @@ -1412,6 +1419,11 @@ struct String(
# FIXME(MSTDL-160):
# Enforce UTF-8 encoding in String so this is actually
# guaranteed to be valid.
# TODO(#933): use when llvm intrinsics can be used at compile time
# debug_assert(
# _is_valid_utf8(self.unsafe_ptr(), self.byte_length()),
# "String doesn't have valid UTF-8 encoding",
# )
return StringSlice(unsafe_from_utf8=self.as_bytes_slice())

@always_inline
Expand Down Expand Up @@ -2160,6 +2172,112 @@ struct String(
return False
return True

@staticmethod
fn from_unicode(values: Variant[List[Int], List[UInt32]]) -> String:
    """Returns a String based on the given Unicode code points.

    Args:
        values: A List of Unicode code points.

    Returns:
        A String containing the concatenated characters. If a Unicode
        codepoint is invalid (negative, above `0x10FFFF` or inside the
        UTF-16 surrogate range `0xD800..0xDFFF`), the parsed String has a
        replacement character (�) at that index.

    Examples:
    ```mojo
    print(String.from_unicode(List[Int](97, 97, 0x10FFFF + 1, 97))) # "aa�a"
    ```

    Notes:
        This method allocates `4 * len(values) + 1` bytes.
    """

    # Resolve the active Variant alternative once instead of per access.
    var is_int_list = values.isa[List[Int]]()
    var buf_length = len(values.unsafe_get[List[Int]]()[]) if (
        is_int_list
    ) else len(values.unsafe_get[List[UInt32]]()[])
    # Worst case: every code point needs 4 UTF-8 bytes, plus one byte for the
    # null terminator (previously the terminator could land one past the end).
    var max_len = 4 * buf_length + 1
    var ptr = UnsafePointer[UInt8].alloc(max_len)
    var current_offset = 0
    for i in range(buf_length):
        var c = values.unsafe_get[List[Int]]()[].unsafe_get(i) if (
            is_int_list
        ) else int(values.unsafe_get[List[UInt32]]()[].unsafe_get(i))
        var num_bytes: Int
        # Validate the code point itself. Checking the encoded bytes is not
        # enough: the encoder masks the high bits, so e.g. 0x210000 would be
        # emitted as a well-formed U+10000.
        if 0 <= c <= 0x10FFFF and not (0xD800 <= c <= 0xDFFF):
            # The helper reports 0 for U+0000, which still takes one byte.
            num_bytes = max(1, _unicode_codepoint_utf8_byte_length(c))
        else:
            debug_assert(False, "Invalid Unicode value at index: " + str(i))
            c = 0xFFFD  # replacement character, always 3 bytes in UTF-8
            num_bytes = 3
        _shift_unicode_to_utf8(ptr.offset(current_offset), c, num_bytes)
        current_offset += num_bytes
    var length = current_offset + 1
    var buf = List[UInt8](unsafe_pointer=ptr, size=length, capacity=max_len)
    buf[current_offset] = 0
    return String(buf^)

@staticmethod
fn from_utf16(values: List[UInt16]) -> String:
    """Returns a String based on the given UTF-16 values.

    Args:
        values: A List of UTF-16 values.

    Returns:
        A String containing the concatenated characters. If a UTF-16 code
        unit is invalid (a high surrogate that is not followed by a low
        surrogate, or a low surrogate without a preceding high surrogate),
        the parsed String has a replacement character (�) at that index.

    Examples:
    ```mojo
    print(String.from_utf16(List[UInt16](97, 97, 0xD800, 97))) # "aa�a"
    ```

    Notes:
        This method allocates `3 * len(values) + 1` bytes.
    """

    alias low_10b = 0b0011_1111_1111  # get lower 10 bits
    var num_units = len(values)
    # Worst case: a single code unit in 0x0800..0xFFFF expands to 3 UTF-8
    # bytes (a surrogate pair is 2 units -> 4 bytes), plus one byte for the
    # null terminator. `2 * len(values)` overflowed the buffer.
    var max_len = 3 * num_units + 1
    var ptr = UnsafePointer[UInt8].alloc(max_len)
    var current_offset = 0
    var values_idx = 0

    while values_idx < num_units:
        var c = int(values.unsafe_get(values_idx))
        var num_bytes: Int
        var consumed = 1  # number of UTF-16 code units used by this char

        if c < 0b1000_0000:  # ASCII
            num_bytes = 1
        elif c < 0x8_00:  # 2 byte long sequence
            num_bytes = 2
        elif c < 0xD8_00 or c >= 0xE0_00:  # 3 byte long sequence
            num_bytes = 3
        else:  # surrogate range 0xD800..0xDFFF
            var c2 = 0
            if values_idx + 1 < num_units:
                c2 = int(values.unsafe_get(values_idx + 1))
            # Only a high surrogate (0xD800..0xDBFF) immediately followed by
            # a low surrogate (0xDC00..0xDFFF) forms a valid pair.
            if c < 0xDC_00 and 0xDC_00 <= c2 < 0xE0_00:
                c = 0x1_0000 + (((c & low_10b) << 10) | (c2 & low_10b))
                num_bytes = 4
                consumed = 2
            else:
                debug_assert(
                    False, "Invalid UTF-16 value at index: " + str(values_idx)
                )
                # Replace only the offending unit so the next one is still
                # decoded on its own.
                c = 0xFFFD
                num_bytes = 3

        _shift_unicode_to_utf8(ptr.offset(current_offset), c, num_bytes)
        current_offset += num_bytes
        values_idx += consumed
    var length = current_offset + 1
    var buf = List[UInt8](unsafe_pointer=ptr, size=length, capacity=max_len)
    buf[current_offset] = 0
    return String(buf^)


# ===----------------------------------------------------------------------=== #
# Utilities
Expand Down
8 changes: 7 additions & 1 deletion stdlib/src/builtin/string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from sys.ffi import C_char
from utils import StringRef
from utils._format import Formattable, Formatter
from utils._visualizers import lldb_formatter_wrapping_type
from utils.string_slice import _is_valid_utf8

from .string import _atol

Expand Down Expand Up @@ -225,7 +226,12 @@ struct StringLiteral(
var new_capacity = length + 1
buffer._realloc(new_capacity)
buffer.size = new_capacity
var data: UnsafePointer[UInt8] = self.unsafe_ptr()
var data = self.unsafe_ptr()
# TODO(#933): use when llvm intrinsics can be used at compile time
# debug_assert(
# _is_valid_utf8(data, length),
# "StringLiteral doesn't have valid UTF-8 encoding",
# )
memcpy(buffer.data, data, length)
(buffer.data + length).init_pointee_move(0)
string._buffer = buffer^
Expand Down
14 changes: 14 additions & 0 deletions stdlib/test/builtin/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,18 @@ def test_chr():
assert_equal("α", chr(945))
assert_equal("➿", chr(10175))
assert_equal("🔥", chr(128293))
assert_equal("�", chr(0xFFFD))
assert_equal("�", chr(0x10FFFF + 1))


def test_unicode():
    # ASCII, 2-, 3- and 4-byte code points, the replacement character itself
    # and one value past the Unicode range (expected to become "�").
    var code_points = List[Int](
        65, 97, 33, 945, 10175, 128293, 0xFFFD, 0x10FFFF + 1
    )
    var parsed = String.from_unicode(code_points)
    assert_equal("Aa!α➿🔥��", parsed)


def test_utf16():
    # BMP units, one surrogate pair (0xD83D 0xDD25 -> 🔥), the replacement
    # character itself and a trailing lone high surrogate (expected "�").
    var code_units = List[UInt16](
        65, 97, 33, 945, 10175, 0xD83D, 0xDD25, 0xFFFD, 0xD800
    )
    var parsed = String.from_utf16(code_units)
    assert_equal("Aa!α➿🔥��", parsed)


def test_string_indexing():
Expand Down Expand Up @@ -1435,6 +1447,8 @@ def main():
test_stringref_strip()
test_ord()
test_chr()
test_unicode()
test_utf16()
test_string_indexing()
test_atol()
test_atol_base_0()
Expand Down