Browse Source

Added TLS initialization and shutdown functions

Fixes https://github.com/libsdl-org/SDL/issues/8576

(cherry picked from commit b5170439368d1ea6f57220c725bdfeb99ad78e82)
(cherry picked from commit 551164812aeca86d4019f8c13a3c06a322481fc6)
Sam Lantinga 9 months ago
parent
commit
87ba287f81

+ 3 - 1
src/SDL.c

@@ -52,6 +52,7 @@
 #include "haptic/SDL_haptic_c.h"
 #include "joystick/SDL_joystick_c.h"
 #include "sensor/SDL_sensor_c.h"
+#include "thread/SDL_thread_c.h"
 
 /* Initialization/Cleanup routines */
 #ifndef SDL_TIMERS_DISABLED
@@ -189,6 +190,7 @@ int SDL_InitSubSystem(Uint32 flags)
         return SDL_SetError("Application didn't initialize properly, did you include SDL_main.h in the file containing your main() function?");
     }
 
+    SDL_InitTLSData();
     SDL_LogInit();
 
     /* Clear the error message */
@@ -522,7 +524,7 @@ void SDL_Quit(void)
      */
     SDL_memset(SDL_SubsystemRefCount, 0x0, sizeof(SDL_SubsystemRefCount));
 
-    SDL_TLSCleanup();
+    SDL_QuitTLSData();
 
     SDL_bInMainQuit = SDL_FALSE;
 }

+ 6 - 0
src/thread/SDL_systhread.h

@@ -59,12 +59,18 @@ extern void SDL_SYS_WaitThread(SDL_Thread *thread);
 /* Mark thread as cleaned up as soon as it exits, without joining. */
 extern void SDL_SYS_DetachThread(SDL_Thread *thread);
 
+/* Initialize the global TLS data */
+extern void SDL_SYS_InitTLSData(void);
+
 /* Get the thread local storage for this thread */
 extern SDL_TLSData *SDL_SYS_GetTLSData(void);
 
 /* Set the thread local storage for this thread */
 extern int SDL_SYS_SetTLSData(SDL_TLSData *data);
 
+/* Quit the global TLS data */
+extern void SDL_SYS_QuitTLSData(void);
+
 /* This is for internal SDL use, so we don't need #ifdefs everywhere. */
 extern SDL_Thread *
 SDL_CreateThreadInternal(int(SDLCALL *fn)(void *), const char *name,

+ 68 - 25
src/thread/SDL_thread.c

@@ -28,6 +28,15 @@
 #include "SDL_hints.h"
 #include "../SDL_error_c.h"
 
+/* The storage is local to the thread, but the IDs are global for the process */
+
+static SDL_atomic_t SDL_tls_allocated;
+
+void SDL_InitTLSData(void)
+{
+    SDL_SYS_InitTLSData();
+}
+
 SDL_TLSID SDL_TLSCreate(void)
 {
     static SDL_atomic_t SDL_tls_id;
@@ -53,6 +62,13 @@ int SDL_TLSSet(SDL_TLSID id, const void *value, void(SDLCALL *destructor)(void *
         return SDL_InvalidParamError("id");
     }
 
+    /* Make sure TLS is initialized.
+     * There's a race condition here if you are calling this from non-SDL threads
+     * and haven't called SDL_Init() on your main thread, but such is life.
+     */
+    SDL_InitTLSData();
+
+    /* Get the storage for the current thread */
     storage = SDL_SYS_GetTLSData();
     if (!storage || (id > storage->limit)) {
         unsigned int i, oldlimit, newlimit;
@@ -69,8 +85,10 @@ int SDL_TLSSet(SDL_TLSID id, const void *value, void(SDLCALL *destructor)(void *
             storage->array[i].destructor = NULL;
         }
         if (SDL_SYS_SetTLSData(storage) != 0) {
+            SDL_free(storage);
             return -1;
         }
+        SDL_AtomicIncRef(&SDL_tls_allocated);
     }
 
     storage->array[id - 1].data = SDL_const_cast(void *, value);
@@ -82,6 +100,7 @@ void SDL_TLSCleanup(void)
 {
     SDL_TLSData *storage;
 
+    /* Cleanup the storage for the current thread */
     storage = SDL_SYS_GetTLSData();
     if (storage) {
         unsigned int i;
@@ -92,6 +111,18 @@ void SDL_TLSCleanup(void)
         }
         SDL_SYS_SetTLSData(NULL);
         SDL_free(storage);
+        (void)SDL_AtomicDecRef(&SDL_tls_allocated);
+    }
+}
+
+void SDL_QuitTLSData(void)
+{
+    SDL_TLSCleanup();
+
+    if (SDL_AtomicGet(&SDL_tls_allocated) == 0) {
+        SDL_SYS_QuitTLSData();
+    } else {
+        /* Some thread hasn't called SDL_CleanupTLS() */
     }
 }
 
@@ -113,40 +144,27 @@ typedef struct SDL_TLSEntry
 static SDL_mutex *SDL_generic_TLS_mutex;
 static SDL_TLSEntry *SDL_generic_TLS;
 
+void SDL_Generic_InitTLSData(void)
+{
+    if (!SDL_generic_TLS_mutex) {
+        SDL_generic_TLS_mutex = SDL_CreateMutex();
+    }
+}
+
 SDL_TLSData *SDL_Generic_GetTLSData(void)
 {
     SDL_threadID thread = SDL_ThreadID();
     SDL_TLSEntry *entry;
     SDL_TLSData *storage = NULL;
 
-#ifndef SDL_THREADS_DISABLED
-    if (!SDL_generic_TLS_mutex) {
-        static SDL_SpinLock tls_lock;
-        SDL_AtomicLock(&tls_lock);
-        if (!SDL_generic_TLS_mutex) {
-            SDL_mutex *mutex = SDL_CreateMutex();
-            SDL_MemoryBarrierRelease();
-            SDL_generic_TLS_mutex = mutex;
-            if (!SDL_generic_TLS_mutex) {
-                SDL_AtomicUnlock(&tls_lock);
-                return NULL;
-            }
-        }
-        SDL_AtomicUnlock(&tls_lock);
-    }
-    SDL_MemoryBarrierAcquire();
     SDL_LockMutex(SDL_generic_TLS_mutex);
-#endif /* SDL_THREADS_DISABLED */
-
     for (entry = SDL_generic_TLS; entry; entry = entry->next) {
         if (entry->thread == thread) {
             storage = entry->storage;
             break;
         }
     }
-#ifndef SDL_THREADS_DISABLED
     SDL_UnlockMutex(SDL_generic_TLS_mutex);
-#endif
 
     return storage;
 }
@@ -155,8 +173,8 @@ int SDL_Generic_SetTLSData(SDL_TLSData *data)
 {
     SDL_threadID thread = SDL_ThreadID();
     SDL_TLSEntry *prev, *entry;
+    int retval = 0;
 
-    /* SDL_Generic_GetTLSData() is always called first, so we can assume SDL_generic_TLS_mutex */
     SDL_LockMutex(SDL_generic_TLS_mutex);
     prev = NULL;
     for (entry = SDL_generic_TLS; entry; entry = entry->next) {
@@ -175,21 +193,44 @@ int SDL_Generic_SetTLSData(SDL_TLSData *data)
         }
         prev = entry;
     }
-    if (!entry) {
+    if (!entry && data) {
         entry = (SDL_TLSEntry *)SDL_malloc(sizeof(*entry));
         if (entry) {
             entry->thread = thread;
             entry->storage = data;
             entry->next = SDL_generic_TLS;
             SDL_generic_TLS = entry;
+        } else {
+            retval = SDL_OutOfMemory();
         }
     }
     SDL_UnlockMutex(SDL_generic_TLS_mutex);
 
-    if (!entry) {
-        return SDL_OutOfMemory();
+    return retval;
+}
+
+void SDL_Generic_QuitTLSData(void)
+{
+    SDL_TLSEntry *entry;
+
+    /* This should have been cleaned up by the time we get here */
+    SDL_assert(!SDL_generic_TLS);
+    if (SDL_generic_TLS) {
+        SDL_LockMutex(SDL_generic_TLS_mutex);
+        for (entry = SDL_generic_TLS; entry; ) {
+            SDL_TLSEntry *next = entry->next;
+            SDL_free(entry->storage);
+            SDL_free(entry);
+            entry = next;
+        }
+        SDL_generic_TLS = NULL;
+        SDL_UnlockMutex(SDL_generic_TLS_mutex);
+    }
+
+    if (SDL_generic_TLS_mutex) {
+        SDL_DestroyMutex(SDL_generic_TLS_mutex);
+        SDL_generic_TLS_mutex = NULL;
     }
-    return 0;
 }
 
 /* Non-thread-safe global error variable */
@@ -328,6 +369,8 @@ SDL_Thread *SDL_CreateThreadWithStackSize(int(SDLCALL *fn)(void *),
     SDL_Thread *thread;
     int ret;
 
+    SDL_InitTLSData();
+
     /* Allocate memory for the thread info structure */
     thread = (SDL_Thread *)SDL_calloc(1, sizeof(*thread));
     if (!thread) {

+ 6 - 6
src/thread/SDL_thread_c.h

@@ -93,17 +93,17 @@ typedef struct
 /* This is how many TLS entries we allocate at once */
 #define TLS_ALLOC_CHUNKSIZE 4
 
-/* Get cross-platform, slow, thread local storage for this thread.
-   This is only intended as a fallback if getting real thread-local
-   storage fails or isn't supported on this platform.
- */
-extern SDL_TLSData *SDL_Generic_GetTLSData(void);
+extern void SDL_InitTLSData(void);
+extern void SDL_QuitTLSData(void);
 
-/* Set cross-platform, slow, thread local storage for this thread.
+/* Generic TLS support.
    This is only intended as a fallback if getting real thread-local
    storage fails or isn't supported on this platform.
  */
+extern void SDL_Generic_InitTLSData(void);
+extern SDL_TLSData *SDL_Generic_GetTLSData(void);
 extern int SDL_Generic_SetTLSData(SDL_TLSData *data);
+extern void SDL_Generic_QuitTLSData(void);
 
 #endif /* SDL_thread_c_h_ */
 

+ 10 - 0
src/thread/generic/SDL_systls.c

@@ -22,6 +22,11 @@
 #include "../../SDL_internal.h"
 #include "../SDL_thread_c.h"
 
+void SDL_SYS_InitTLSData(void)
+{
+    SDL_Generic_InitTLSData();
+}
+
 SDL_TLSData *SDL_SYS_GetTLSData(void)
 {
     return SDL_Generic_GetTLSData();
@@ -32,4 +37,9 @@ int SDL_SYS_SetTLSData(SDL_TLSData *data)
     return SDL_Generic_SetTLSData(data);
 }
 
+void SDL_SYS_QuitTLSData(void)
+{
+    SDL_Generic_QuitTLSData();
+}
+
 /* vi: set ts=4 sw=4 expandtab: */

+ 28 - 14
src/thread/pthread/SDL_systls.c

@@ -30,27 +30,27 @@
 static pthread_key_t thread_local_storage = INVALID_PTHREAD_KEY;
 static SDL_bool generic_local_storage = SDL_FALSE;
 
-SDL_TLSData *SDL_SYS_GetTLSData(void)
+void SDL_SYS_InitTLSData(void)
 {
     if (thread_local_storage == INVALID_PTHREAD_KEY && !generic_local_storage) {
-        static SDL_SpinLock lock;
-        SDL_AtomicLock(&lock);
-        if (thread_local_storage == INVALID_PTHREAD_KEY && !generic_local_storage) {
-            pthread_key_t storage;
-            if (pthread_key_create(&storage, NULL) == 0) {
-                SDL_MemoryBarrierRelease();
-                thread_local_storage = storage;
-            } else {
-                generic_local_storage = SDL_TRUE;
-            }
+        if (pthread_key_create(&thread_local_storage, NULL) != 0) {
+            thread_local_storage = INVALID_PTHREAD_KEY;
+            SDL_Generic_InitTLSData();
+            generic_local_storage = SDL_TRUE;
         }
-        SDL_AtomicUnlock(&lock);
     }
+}
+
+SDL_TLSData *SDL_SYS_GetTLSData(void)
+{
     if (generic_local_storage) {
         return SDL_Generic_GetTLSData();
     }
-    SDL_MemoryBarrierAcquire();
-    return (SDL_TLSData *)pthread_getspecific(thread_local_storage);
+
+    if (thread_local_storage != INVALID_PTHREAD_KEY) {
+        return (SDL_TLSData *)pthread_getspecific(thread_local_storage);
+    }
+    return NULL;
 }
 
 int SDL_SYS_SetTLSData(SDL_TLSData *data)
@@ -58,10 +58,24 @@ int SDL_SYS_SetTLSData(SDL_TLSData *data)
     if (generic_local_storage) {
         return SDL_Generic_SetTLSData(data);
     }
+
     if (pthread_setspecific(thread_local_storage, data) != 0) {
         return SDL_SetError("pthread_setspecific() failed");
     }
     return 0;
 }
 
+void SDL_SYS_QuitTLSData(void)
+{
+    if (generic_local_storage) {
+        SDL_Generic_QuitTLSData();
+        generic_local_storage = SDL_FALSE;
+    } else {
+        if (thread_local_storage != INVALID_PTHREAD_KEY) {
+            pthread_key_delete(thread_local_storage);
+            thread_local_storage = INVALID_PTHREAD_KEY;
+        }
+    }
+}
+
 /* vi: set ts=4 sw=4 expandtab: */

+ 19 - 6
src/thread/stdcpp/SDL_systhread.cpp

@@ -157,16 +157,29 @@ SDL_SYS_DetachThread(SDL_Thread *thread)
     }
 }
 
-extern "C" SDL_TLSData *
-SDL_SYS_GetTLSData(void)
+static thread_local SDL_TLSData *thread_local_storage;
+
+extern "C"
+void SDL_SYS_InitTLSData(void)
 {
-    return SDL_Generic_GetTLSData();
 }
 
-extern "C" int
-SDL_SYS_SetTLSData(SDL_TLSData *data)
+extern "C"
+SDL_TLSData * SDL_SYS_GetTLSData(void)
+{
+    return thread_local_storage;
+}
+
+extern "C"
+int SDL_SYS_SetTLSData(SDL_TLSData *data)
+{
+    thread_local_storage = data;
+    return 0;
+}
+
+extern "C"
+void SDL_SYS_QuitTLSData(void)
 {
-    return SDL_Generic_SetTLSData(data);
 }
 
 /* vi: set ts=4 sw=4 expandtab: */

+ 29 - 15
src/thread/windows/SDL_systls.c

@@ -43,27 +43,27 @@
 static DWORD thread_local_storage = TLS_OUT_OF_INDEXES;
 static SDL_bool generic_local_storage = SDL_FALSE;
 
-SDL_TLSData *SDL_SYS_GetTLSData(void)
+void SDL_SYS_InitTLSData(void)
 {
     if (thread_local_storage == TLS_OUT_OF_INDEXES && !generic_local_storage) {
-        static SDL_SpinLock lock;
-        SDL_AtomicLock(&lock);
-        if (thread_local_storage == TLS_OUT_OF_INDEXES && !generic_local_storage) {
-            DWORD storage = TlsAlloc();
-            if (storage != TLS_OUT_OF_INDEXES) {
-                SDL_MemoryBarrierRelease();
-                thread_local_storage = storage;
-            } else {
-                generic_local_storage = SDL_TRUE;
-            }
+        thread_local_storage = TlsAlloc();
+        if (thread_local_storage == TLS_OUT_OF_INDEXES) {
+            SDL_Generic_InitTLSData();
+            generic_local_storage = SDL_TRUE;
         }
-        SDL_AtomicUnlock(&lock);
     }
+}
+
+SDL_TLSData *SDL_SYS_GetTLSData(void)
+{
     if (generic_local_storage) {
         return SDL_Generic_GetTLSData();
     }
-    SDL_MemoryBarrierAcquire();
-    return (SDL_TLSData *)TlsGetValue(thread_local_storage);
+
+    if (thread_local_storage != TLS_OUT_OF_INDEXES) {
+        return (SDL_TLSData *)TlsGetValue(thread_local_storage);
+    }
+    return NULL;
 }
 
 int SDL_SYS_SetTLSData(SDL_TLSData *data)
@@ -71,12 +71,26 @@ int SDL_SYS_SetTLSData(SDL_TLSData *data)
     if (generic_local_storage) {
         return SDL_Generic_SetTLSData(data);
     }
+
     if (!TlsSetValue(thread_local_storage, data)) {
-        return SDL_SetError("TlsSetValue() failed");
+        return WIN_SetError("TlsSetValue()");
     }
     return 0;
 }
 
+void SDL_SYS_QuitTLSData(void)
+{
+    if (generic_local_storage) {
+        SDL_Generic_QuitTLSData();
+        generic_local_storage = SDL_FALSE;
+    } else {
+        if (thread_local_storage != TLS_OUT_OF_INDEXES) {
+            TlsFree(thread_local_storage);
+            thread_local_storage = TLS_OUT_OF_INDEXES;
+        }
+    }
+}
+
 #endif /* SDL_THREAD_WINDOWS */
 
 /* vi: set ts=4 sw=4 expandtab: */