# Copyright 2025-present Iyad
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import concurrent.futures as cf
import datetime
import ipaddress
import platform
import socket
import subprocess
import threading
import time
from collections import deque
from typing import TYPE_CHECKING, NamedTuple, NotRequired, TypedDict, Unpack
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Final, Literal
__all__ = (
"DecomposedSonyflake",
"InvalidBitsMachineID",
"InvalidBitsSequence",
"InvalidBitsTime",
"InvalidMachineID",
"InvalidSequence",
"InvalidTimeUnit",
"MachineIDCheckFailure",
"NoPrivateAddress",
"OverTimeLimit",
"Sonyflake",
"SonyflakeError",
"StartTimeAhead",
)
DEFAULT_BITS_TIME: Final = 39
DEFAULT_BITS_SEQUENCE: Final = 8
DEFAULT_BITS_MACHINE_ID: Final = 16
DEFAULT_TIME_UNIT: Final = 10_000_000 # 10msec
SECOND_NS: Final = 1_000_000_000
[docs]
class SonyflakeError(Exception):
"""Base class for all sonyflake errors."""
[docs]
class InvalidBitsSequence(SonyflakeError):
"""Raised when the bit length for the sequence is out of valid range (0-30)."""
def __init__(self) -> None:
msg = "bit length for sequence number must be between 0 and 30 (inclusive)."
super().__init__(msg)
[docs]
class InvalidBitsMachineID(SonyflakeError):
"""Raised when the bit length for the machine ID is out of valid range (0-30)."""
def __init__(self) -> None:
msg = "bit length for machine id must be between 0 and 30 (inclusive)."
super().__init__(msg)
[docs]
class InvalidBitsTime(SonyflakeError):
"""Raised when the computed time bit length is too small to represent valid timestamps."""
def __init__(self) -> None:
msg = "bit length for time must be at least 32."
super().__init__(msg)
[docs]
class InvalidTimeUnit(SonyflakeError):
"""Raised when the provided time unit is too small."""
def __init__(self) -> None:
msg = "time unit must be atleast 1 millisecond."
super().__init__(msg)
[docs]
class InvalidSequence(SonyflakeError):
"""Raised when the sequence number is out of valid range."""
def __init__(self, bits_sequence: int) -> None:
max_value = (1 << bits_sequence) - 1
msg = f"sequence number must be between 0 and {max_value} (inclusive) for bits_sequence={bits_sequence}."
super().__init__(msg)
[docs]
class InvalidMachineID(SonyflakeError):
"""Raised when the computed machine ID is out of range."""
def __init__(self, bits_machine_id: int) -> None:
max_value = (1 << bits_machine_id) - 1
msg = f"machine id must be between 0 and {max_value} (inclusive) for bits_machine_id={bits_machine_id}"
super().__init__(msg)
[docs]
class MachineIDCheckFailure(SonyflakeError):
"""Raised when a machine ID fails the validation check.
.. versionadded:: 2.0
"""
def __init__(self, machine_id: int) -> None:
msg = f"machine id validation failed for {machine_id=}."
super().__init__(msg)
[docs]
class StartTimeAhead(SonyflakeError):
"""Raised when the provided start time is ahead of the current time."""
def __init__(self) -> None:
msg = "start time must not be in the future."
super().__init__(msg)
[docs]
class OverTimeLimit(SonyflakeError):
"""Raised when the elapsed time exceeds the representable limit."""
def __init__(self, max_elapsed_time: int) -> None:
msg = f"elapsed time exceeded the limit: max allowed is {max_elapsed_time}"
super().__init__(msg)
[docs]
class NoPrivateAddress(SonyflakeError):
"""Raised when no private IPv4 address could be determined."""
def __init__(self) -> None:
msg = "failed to determine private IPv4 address."
super().__init__(msg)
[docs]
class DecomposedSonyflake(NamedTuple):
"""Represents a decomposed Sonyflake.
This structure holds the individual components of a 64-bit Sonyflake ID.
.. versionchanged:: 2.0
The ``id`` attribute was removed.
Attributes
----------
time : int
The time portion of the ID.
sequence : int
The sequence number portion of the ID.
machine_id : int
The machine identifier portion of the ID.
"""
time: int
sequence: int
machine_id: int
def _pick_private_ip(ips: list[str]) -> ipaddress.IPv4Address:
for ip_str in ips:
ip = ipaddress.IPv4Address(ip_str)
if ip.is_loopback:
continue
if ip.is_private or ip.is_link_local:
return ip
raise NoPrivateAddress
def _macos_reliable_ip() -> ipaddress.IPv4Address:
result = subprocess.run(["/usr/bin/osascript", "-e", "IPv4 address of (system info)"], capture_output=True, check=True)
ip_str = result.stdout.decode("utf-8").removesuffix("\n")
return ipaddress.IPv4Address(ip_str)
def _lower_16bit_private_ip() -> int:
# socket.gethostbyname_ex(socket.getfqdn()) fails on some macbooks (tbh, most from what i have tested).
# reference: https://github.com/python/cpython/issues/79345
try:
*_, ips = socket.gethostbyname_ex(socket.getfqdn())
ip = _pick_private_ip(ips)
except socket.gaierror:
if platform.system() != "Darwin":
raise
ip = _macos_reliable_ip()
ip_bytes = ip.packed
return (ip_bytes[2] << 8) + ip_bytes[3]
def _utcnow() -> datetime.datetime:
return datetime.datetime.now(datetime.UTC)
# Thanks to the discussion with Sinbad :).
class _HybridLock:
__slots__ = ("_internal_lock", "_locked", "_waiters")
_waiters: deque[cf.Future[None]]
_internal_lock: threading.Lock
_locked: bool
def __init__(self) -> None:
self._waiters = deque()
self._internal_lock = threading.Lock()
self._locked = False
async def __aenter__(self) -> None:
with self._internal_lock:
# Incase you are wondering why the all check
# https://discuss.python.org/t/preventing-yield-inside-certain-context-managers/1091
# https://peps.python.org/pep-0789/
if not self._locked and all(w.cancelled() for w in self._waiters):
self._locked = True
return
fut: cf.Future[None] = cf.Future()
self._waiters.append(fut)
try:
await asyncio.wrap_future(fut)
except (asyncio.CancelledError, cf.CancelledError):
with self._internal_lock:
if fut.done() and not fut.cancelled():
self._locked = False
self.__wake_next()
raise
async def __aexit__(self, *_: object) -> Literal[False]:
return self.__release()
def __enter__(self) -> None:
with self._internal_lock:
if not self._locked and all(w.cancelled() for w in self._waiters):
self._locked = True
return
fut: cf.Future[None] = cf.Future()
self._waiters.append(fut)
try:
fut.result()
except cf.CancelledError:
with self._internal_lock:
if fut.done() and not fut.cancelled():
self._locked = False
self.__wake_next()
raise
def __exit__(self, *_: object) -> Literal[False]:
return self.__release()
def __release(self) -> Literal[False]:
with self._internal_lock:
self._locked = False
self.__wake_next()
return False
def __wake_next(self) -> None:
# Not acquiring the lock here, since this function is gonna be called
# from a locked context.
while self._waiters:
fut = self._waiters.popleft()
if not fut.cancelled() and not fut.done():
self._locked = True
fut.set_result(None)
break
class _SonyflakeOptions(TypedDict):
bits_sequence: NotRequired[int]
bits_machine_id: NotRequired[int]
time_unit: NotRequired[datetime.timedelta]
start_time: datetime.datetime
machine_id: NotRequired[int]
check_machine_id: NotRequired[Callable[[int], bool]]
[docs]
class Sonyflake:
"""A distributed unique ID generator.
Parameters
----------
bits_sequence : int, optional
Number of bits allocated for the sequence number (the default is `8`).
bits_machine_id : int, optional
Number of bits allocated for the machine ID (the default is `16`).
time_unit : datetime.timedelta, optional
Minimum time unit used for incrementing IDs (the default is 10 milliseconds).
start_time : datetime.datetime
The custom epoch from which time is measured.
machine_id : int, optional
Custom machine ID to use (the default is the lower 16 bits of the machine's private IP address).
check_machine_id : Callable[[int], bool], optional
Function to validate the generated or provided machine ID (the default is `None`, which disables validation).
Raises
------
ValueError
If start time is not provided.
InvalidBitsSequence
If the provided bit length for the sequence number is invalid.
InvalidBitsMachineID
If the provided bit length for the machine ID is invalid.
InvalidTimeUnit
If the time unit is smaller than 1 millisecond.
InvalidMachineID
If the provided or generated machine ID is invalid.
StartTimeAhead
If the start time is set in the future.
"""
__slots__ = (
"_bits_machine_id",
"_bits_sequence",
"_bits_time",
"_elapsed_time",
"_lock",
"_machine_id",
"_sequence",
"_start_time",
"_time_unit",
)
_bits_machine_id: int
_bits_sequence: int
_bits_time: int
_elapsed_time: int
_lock: _HybridLock
_machine_id: int
_sequence: int
_start_time: int
_time_unit: int
def __init__(self, **options: Unpack[_SonyflakeOptions]) -> None:
bits_sequence = options.pop("bits_sequence", DEFAULT_BITS_SEQUENCE)
if not 0 <= bits_sequence <= 30:
raise InvalidBitsSequence
bits_machine_id = options.pop("bits_machine_id", DEFAULT_BITS_MACHINE_ID)
if not 0 <= bits_machine_id <= 30:
raise InvalidBitsMachineID
bits_time = 63 - bits_sequence - bits_machine_id
if bits_time < 32:
raise InvalidBitsTime
self._bits_sequence = bits_sequence
self._bits_machine_id = bits_machine_id
self._bits_time = bits_time
try:
time_unit = options.pop("time_unit")
except KeyError:
self._time_unit = DEFAULT_TIME_UNIT
else:
if time_unit < datetime.timedelta(milliseconds=1):
raise InvalidTimeUnit
self._time_unit = int(time_unit.total_seconds() * SECOND_NS)
try:
start_time = options["start_time"]
except KeyError:
msg = "'start_time' is required"
raise ValueError(msg) from None
else:
start_time = start_time.astimezone(datetime.UTC)
if start_time > _utcnow():
raise StartTimeAhead
self._start_time = self._to_internal_time(start_time)
self._elapsed_time = 0
self._sequence = (1 << self._bits_sequence) - 1
try:
machine_id = options.pop("machine_id")
except KeyError:
machine_id = _lower_16bit_private_ip()
if not 0 <= machine_id < (1 << bits_machine_id):
raise InvalidMachineID(bits_machine_id)
try:
check_machine_id = options.pop("check_machine_id")
except KeyError:
pass
else:
if not check_machine_id(machine_id):
raise MachineIDCheckFailure(machine_id)
self._machine_id = machine_id
self._lock = _HybridLock()
[docs]
def next_id(self) -> int:
"""Return the next unique id.
Returns
-------
int
A 64-bit Sonyflake ID.
Raises
------
OverTimeLimit
If the elapsed time exceeds the maximum representable value.
"""
mask_sequence = (1 << self._bits_sequence) - 1
with self._lock:
current = self._current_elapsed_time()
if self._elapsed_time < current:
self._elapsed_time = current
self._sequence = 0
else:
self._sequence = (self._sequence + 1) & mask_sequence
if self._sequence == 0:
self._elapsed_time += 1
overtime = self._elapsed_time - current
self._sleep(overtime)
return self._to_id()
[docs]
async def next_id_async(self) -> int:
"""Return the next unique id.
This coroutine is the asynchronous version of :meth:`Sonyflake.next_id`
and is intended for use in asynchronous applications.
.. versionadded:: 2.0
Returns
-------
int
A 64-bit Sonyflake ID.
Raises
------
OverTimeLimit
If the elapsed time exceeds the maximum representable value.
"""
mask_sequence = (1 << self._bits_sequence) - 1
async with self._lock:
current = self._current_elapsed_time()
if self._elapsed_time < current:
self._elapsed_time = current
self._sequence = 0
else:
self._sequence = (self._sequence + 1) & mask_sequence
if self._sequence == 0:
self._elapsed_time += 1
overtime = self._elapsed_time - current
await self._sleep_async(overtime)
return self._to_id()
[docs]
def to_time(self, sonyflake_id: int) -> datetime.datetime:
"""Convert a Sonyflake ID to its corresponding UTC datetime.
Parameters
----------
sonyflake_id : int
The Sonyflake ID to convert.
Returns
-------
datetime.datetime
The UTC datetime corresponding to the given ID.
"""
ns = (self._start_time + self._time_part(sonyflake_id)) * self._time_unit
return datetime.datetime.fromtimestamp(ns / SECOND_NS, tz=datetime.UTC)
[docs]
def compose(self, dt: datetime.datetime, sequence: int, machine_id: int) -> int:
"""Compose a Sonyflake ID from datetime, sequence, and machine ID.
.. versionchanged:: 2.0
The ``dt`` parameter no longer needs to be in UTC and maybe
timezone-aware or naieve. If ``dt`` is naive (has no timezone),
it is assumed to be in the system local timezone.
Parameters
----------
dt : datetime.datetime
The datetime at which the ID is generated.
sequence : int
A number between 0 and 2^bits_sequence - 1 (inclusive).
machine_id : int
A number between 0 and 2^bits_machine_id - 1 (inclusive).
Returns
-------
int
The composed Sonyflake ID.
Raises
------
StartTimeAhead
If the datetime is before the configured start time.
OverTimeLimit
If the elapsed time exceeds the representable range.
InvalidSequence
If the sequence value is out of range.
InvalidMachineID
If the machine ID is out of range.
"""
elapsed_time = self._to_internal_time(dt) - self._start_time
if elapsed_time < 0:
raise StartTimeAhead
max_elapsed_time = (1 << self._bits_time) - 1
if elapsed_time > max_elapsed_time:
raise OverTimeLimit(max_elapsed_time)
if not 0 <= sequence < (1 << self._bits_sequence):
raise InvalidSequence(self._bits_sequence)
if not 0 <= machine_id < (1 << self._bits_machine_id):
raise InvalidMachineID(self._bits_machine_id)
time = elapsed_time << (self._bits_sequence + self._bits_machine_id)
seq = sequence << self._bits_machine_id
return time | seq | machine_id
[docs]
def decompose(self, sonyflake_id: int) -> DecomposedSonyflake:
"""Decompose a Sonyflake ID into its components.
Parameters
----------
sonyflake_id : int
The Sonyflake ID to decompose.
Returns
-------
DecomposedSonyflake
A named tuple with the fields: `time`, `sequence`, `machine_id`.
"""
time = self._time_part(sonyflake_id)
sequence = self._sequence_part(sonyflake_id)
machine_id = self._machine_id_part(sonyflake_id)
return DecomposedSonyflake(
time=time,
sequence=sequence,
machine_id=machine_id,
)
def _to_internal_time(self, dt: datetime.datetime) -> int:
dt = dt.astimezone(tz=datetime.UTC)
unix_ns = int(dt.timestamp() * SECOND_NS)
return unix_ns // self._time_unit
def _current_elapsed_time(self) -> int:
return self._to_internal_time(_utcnow()) - self._start_time
def _to_id(self) -> int:
max_elapsed_time = (1 << self._bits_time) - 1
if self._elapsed_time > max_elapsed_time:
raise OverTimeLimit(max_elapsed_time)
time = self._elapsed_time << (self._bits_sequence + self._bits_machine_id)
sequence = self._sequence << self._bits_machine_id
return time | sequence | self._machine_id
def _time_part(self, sonyflake_id: int) -> int:
return sonyflake_id >> (self._bits_sequence + self._bits_machine_id)
def _sequence_part(self, sonyflake_id: int) -> int:
mask_sequence = ((1 << self._bits_sequence) - 1) << self._bits_machine_id
return (sonyflake_id & mask_sequence) >> self._bits_machine_id
def _machine_id_part(self, sonyflake_id: int) -> int:
mask_machine_id = (1 << self._bits_machine_id) - 1
return sonyflake_id & mask_machine_id
def _sleep(self, overtime: int) -> None:
now_ns = int(_utcnow().timestamp() * SECOND_NS)
sleep_ns = (overtime * self._time_unit) - (now_ns % self._time_unit)
time.sleep(sleep_ns / SECOND_NS)
async def _sleep_async(self, overtime: int) -> None:
now_ns = int(_utcnow().timestamp() * SECOND_NS)
sleep_ns = (overtime * self._time_unit) - (now_ns % self._time_unit)
await asyncio.sleep(sleep_ns / SECOND_NS)
def __repr__(self) -> str:
start_time = str(datetime.datetime.fromtimestamp((self._start_time * self._time_unit) / SECOND_NS, tz=datetime.UTC))
elapsed_time = str(datetime.timedelta(seconds=(self._elapsed_time * self._time_unit) / SECOND_NS))
time_unit = str(datetime.timedelta(seconds=self._time_unit / SECOND_NS))
return (
f"{self.__class__.__name__}("
f"bits_machine_id={self._bits_machine_id!r}, "
f"bits_sequence={self._bits_sequence!r}, "
f"bits_time={self._bits_time!r}, "
f"elapsed_time={elapsed_time!r}, "
f"machine_id={self._machine_id!r}, "
f"sequence={self._sequence!r}, "
f"start_time={start_time!r}, "
f"time_unit={time_unit!r})"
)