Forráskód Böngészése

Added TLS initialization and shutdown functions

Fixes https://github.com/libsdl-org/SDL/issues/8576
Sam Lantinga 9 hónapja
szülő
commit
b517043936

+ 3 - 3
src/SDL.c

@@ -50,6 +50,7 @@
 #include "render/SDL_sysrender.h"
 #include "sensor/SDL_sensor_c.h"
 #include "stdlib/SDL_getenv_c.h"
+#include "thread/SDL_thread_c.h"
 #include "video/SDL_pixels_c.h"
 #include "video/SDL_video_c.h"
 
@@ -189,6 +190,7 @@ int SDL_InitSubSystem(Uint32 flags)
         return SDL_SetError("Application didn't initialize properly, did you include SDL_main.h in the file containing your main() function?");
     }
 
+    SDL_InitTLSData();
     SDL_InitLog();
     SDL_InitProperties();
     SDL_GetGlobalProperties();
@@ -563,10 +565,8 @@ void SDL_Quit(void)
     SDL_memset(SDL_SubsystemRefCount, 0x0, sizeof(SDL_SubsystemRefCount));
 
     SDL_FlushEventMemory(0);
-
-    SDL_CleanupTLS();
-
     SDL_FreeEnvironmentMemory();
+    SDL_QuitTLSData();
 
     SDL_bInMainQuit = SDL_FALSE;
 }

+ 6 - 0
src/thread/SDL_systhread.h

@@ -54,12 +54,18 @@ extern void SDL_SYS_WaitThread(SDL_Thread *thread);
 /* Mark thread as cleaned up as soon as it exits, without joining. */
 extern void SDL_SYS_DetachThread(SDL_Thread *thread);
 
+/* Initialize the global TLS data */
+extern void SDL_SYS_InitTLSData(void);
+
 /* Get the thread local storage for this thread */
 extern SDL_TLSData *SDL_SYS_GetTLSData(void);
 
 /* Set the thread local storage for this thread */
 extern int SDL_SYS_SetTLSData(SDL_TLSData *data);
 
+/* Quit the global TLS data */
+extern void SDL_SYS_QuitTLSData(void);
+
 /* A helper function for setting up a thread with a stack size. */
 extern SDL_Thread *SDL_CreateThreadWithStackSize(SDL_ThreadFunction fn, const char *name, size_t stacksize, void *data);
 

+ 69 - 23
src/thread/SDL_thread.c

@@ -26,6 +26,15 @@
 #include "SDL_systhread.h"
 #include "../SDL_error_c.h"
 
+/* The storage is local to the thread, but the IDs are global for the process */
+
+static SDL_AtomicInt SDL_tls_allocated;
+
+void SDL_InitTLSData(void)
+{
+    SDL_SYS_InitTLSData();
+}
+
 SDL_TLSID SDL_CreateTLS(void)
 {
     static SDL_AtomicInt SDL_tls_id;
@@ -51,6 +60,13 @@ int SDL_SetTLS(SDL_TLSID id, const void *value, SDL_TLSDestructorCallback destru
         return SDL_InvalidParamError("id");
     }
 
+    /* Make sure TLS is initialized.
+     * There's a race condition here if you are calling this from non-SDL threads
+     * and haven't called SDL_Init() on your main thread, but such is life.
+     */
+    SDL_InitTLSData();
+
+    /* Get the storage for the current thread */
     storage = SDL_SYS_GetTLSData();
     if (!storage || (id > storage->limit)) {
         unsigned int i, oldlimit, newlimit;
@@ -69,8 +85,10 @@ int SDL_SetTLS(SDL_TLSID id, const void *value, SDL_TLSDestructorCallback destru
             storage->array[i].destructor = NULL;
         }
         if (SDL_SYS_SetTLSData(storage) != 0) {
+            SDL_free(storage);
             return -1;
         }
+        SDL_AtomicIncRef(&SDL_tls_allocated);
     }
 
     storage->array[id - 1].data = SDL_const_cast(void *, value);
@@ -82,6 +100,7 @@ void SDL_CleanupTLS(void)
 {
     SDL_TLSData *storage;
 
+    /* Cleanup the storage for the current thread */
     storage = SDL_SYS_GetTLSData();
     if (storage) {
         unsigned int i;
@@ -92,6 +111,18 @@ void SDL_CleanupTLS(void)
         }
         SDL_SYS_SetTLSData(NULL);
         SDL_free(storage);
+        (void)SDL_AtomicDecRef(&SDL_tls_allocated);
+    }
+}
+
+void SDL_QuitTLSData(void)
+{
+    SDL_CleanupTLS();
+
+    if (SDL_AtomicGet(&SDL_tls_allocated) == 0) {
+        SDL_SYS_QuitTLSData();
+    } else {
+        /* Some thread hasn't called SDL_CleanupTLS() */
     }
 }
 
@@ -113,40 +144,27 @@ typedef struct SDL_TLSEntry
 static SDL_Mutex *SDL_generic_TLS_mutex;
 static SDL_TLSEntry *SDL_generic_TLS;
 
+void SDL_Generic_InitTLSData(void)
+{
+    if (!SDL_generic_TLS_mutex) {
+        SDL_generic_TLS_mutex = SDL_CreateMutex();
+    }
+}
+
 SDL_TLSData *SDL_Generic_GetTLSData(void)
 {
     SDL_ThreadID thread = SDL_GetCurrentThreadID();
     SDL_TLSEntry *entry;
     SDL_TLSData *storage = NULL;
 
-#ifndef SDL_THREADS_DISABLED
-    if (!SDL_generic_TLS_mutex) {
-        static SDL_SpinLock tls_lock;
-        SDL_LockSpinlock(&tls_lock);
-        if (!SDL_generic_TLS_mutex) {
-            SDL_Mutex *mutex = SDL_CreateMutex();
-            SDL_MemoryBarrierRelease();
-            SDL_generic_TLS_mutex = mutex;
-            if (!SDL_generic_TLS_mutex) {
-                SDL_UnlockSpinlock(&tls_lock);
-                return NULL;
-            }
-        }
-        SDL_UnlockSpinlock(&tls_lock);
-    }
-    SDL_MemoryBarrierAcquire();
     SDL_LockMutex(SDL_generic_TLS_mutex);
-#endif /* SDL_THREADS_DISABLED */
-
     for (entry = SDL_generic_TLS; entry; entry = entry->next) {
         if (entry->thread == thread) {
             storage = entry->storage;
             break;
         }
     }
-#ifndef SDL_THREADS_DISABLED
     SDL_UnlockMutex(SDL_generic_TLS_mutex);
-#endif
 
     return storage;
 }
@@ -155,8 +173,8 @@ int SDL_Generic_SetTLSData(SDL_TLSData *data)
 {
     SDL_ThreadID thread = SDL_GetCurrentThreadID();
     SDL_TLSEntry *prev, *entry;
+    int retval = 0;
 
-    /* SDL_Generic_GetTLSData() is always called first, so we can assume SDL_generic_TLS_mutex */
     SDL_LockMutex(SDL_generic_TLS_mutex);
     prev = NULL;
     for (entry = SDL_generic_TLS; entry; entry = entry->next) {
@@ -175,18 +193,44 @@ int SDL_Generic_SetTLSData(SDL_TLSData *data)
         }
         prev = entry;
     }
-    if (!entry) {
+    if (!entry && data) {
         entry = (SDL_TLSEntry *)SDL_malloc(sizeof(*entry));
         if (entry) {
             entry->thread = thread;
             entry->storage = data;
             entry->next = SDL_generic_TLS;
             SDL_generic_TLS = entry;
+        } else {
+            retval = -1;
         }
     }
     SDL_UnlockMutex(SDL_generic_TLS_mutex);
 
-    return entry ? 0 : -1;
+    return retval;
+}
+
+void SDL_Generic_QuitTLSData(void)
+{
+    SDL_TLSEntry *entry;
+
+    /* This should have been cleaned up by the time we get here */
+    SDL_assert(!SDL_generic_TLS);
+    if (SDL_generic_TLS) {
+        SDL_LockMutex(SDL_generic_TLS_mutex);
+        for (entry = SDL_generic_TLS; entry; ) {
+            SDL_TLSEntry *next = entry->next;
+            SDL_free(entry->storage);
+            SDL_free(entry);
+            entry = next;
+        }
+        SDL_generic_TLS = NULL;
+        SDL_UnlockMutex(SDL_generic_TLS_mutex);
+    }
+
+    if (SDL_generic_TLS_mutex) {
+        SDL_DestroyMutex(SDL_generic_TLS_mutex);
+        SDL_generic_TLS_mutex = NULL;
+    }
 }
 
 /* Non-thread-safe global error variable */
@@ -327,6 +371,8 @@ SDL_Thread *SDL_CreateThreadWithPropertiesRuntime(SDL_PropertiesID props,
         return NULL;
     }
 
+    SDL_InitTLSData();
+
     SDL_Thread *thread = (SDL_Thread *)SDL_calloc(1, sizeof(*thread));
     if (!thread) {
         return NULL;

+ 6 - 6
src/thread/SDL_thread_c.h

@@ -89,16 +89,16 @@ typedef struct
 /* This is how many TLS entries we allocate at once */
 #define TLS_ALLOC_CHUNKSIZE 4
 
-/* Get cross-platform, slow, thread local storage for this thread.
-   This is only intended as a fallback if getting real thread-local
-   storage fails or isn't supported on this platform.
- */
-extern SDL_TLSData *SDL_Generic_GetTLSData(void);
+extern void SDL_InitTLSData(void);
+extern void SDL_QuitTLSData(void);
 
-/* Set cross-platform, slow, thread local storage for this thread.
+/* Generic TLS support.
    This is only intended as a fallback if getting real thread-local
    storage fails or isn't supported on this platform.
  */
+extern void SDL_Generic_InitTLSData(void);
+extern SDL_TLSData *SDL_Generic_GetTLSData(void);
 extern int SDL_Generic_SetTLSData(SDL_TLSData *data);
+extern void SDL_Generic_QuitTLSData(void);
 
 #endif /* SDL_thread_c_h_ */

+ 11 - 0
src/thread/generic/SDL_systls.c

@@ -22,6 +22,11 @@
 #include "SDL_internal.h"
 #include "../SDL_thread_c.h"
 
+void SDL_SYS_InitTLSData(void)
+{
+    SDL_Generic_InitTLSData();
+}
+
 SDL_TLSData *SDL_SYS_GetTLSData(void)
 {
     return SDL_Generic_GetTLSData();
@@ -31,3 +36,9 @@ int SDL_SYS_SetTLSData(SDL_TLSData *data)
 {
     return SDL_Generic_SetTLSData(data);
 }
+
+void SDL_SYS_QuitTLSData(void)
+{
+    SDL_Generic_QuitTLSData();
+}
+

+ 28 - 14
src/thread/pthread/SDL_systls.c

@@ -29,27 +29,27 @@
 static pthread_key_t thread_local_storage = INVALID_PTHREAD_KEY;
 static SDL_bool generic_local_storage = SDL_FALSE;
 
-SDL_TLSData *SDL_SYS_GetTLSData(void)
+void SDL_SYS_InitTLSData(void)
 {
     if (thread_local_storage == INVALID_PTHREAD_KEY && !generic_local_storage) {
-        static SDL_SpinLock lock;
-        SDL_LockSpinlock(&lock);
-        if (thread_local_storage == INVALID_PTHREAD_KEY && !generic_local_storage) {
-            pthread_key_t storage;
-            if (pthread_key_create(&storage, NULL) == 0) {
-                SDL_MemoryBarrierRelease();
-                thread_local_storage = storage;
-            } else {
-                generic_local_storage = SDL_TRUE;
-            }
+        if (pthread_key_create(&thread_local_storage, NULL) != 0) {
+            thread_local_storage = INVALID_PTHREAD_KEY;
+            SDL_Generic_InitTLSData();
+            generic_local_storage = SDL_TRUE;
         }
-        SDL_UnlockSpinlock(&lock);
     }
+}
+
+SDL_TLSData *SDL_SYS_GetTLSData(void)
+{
     if (generic_local_storage) {
         return SDL_Generic_GetTLSData();
     }
-    SDL_MemoryBarrierAcquire();
-    return (SDL_TLSData *)pthread_getspecific(thread_local_storage);
+
+    if (thread_local_storage != INVALID_PTHREAD_KEY) {
+        return (SDL_TLSData *)pthread_getspecific(thread_local_storage);
+    }
+    return NULL;
 }
 
 int SDL_SYS_SetTLSData(SDL_TLSData *data)
@@ -57,8 +57,22 @@ int SDL_SYS_SetTLSData(SDL_TLSData *data)
     if (generic_local_storage) {
         return SDL_Generic_SetTLSData(data);
     }
+
     if (pthread_setspecific(thread_local_storage, data) != 0) {
         return SDL_SetError("pthread_setspecific() failed");
     }
     return 0;
 }
+
+void SDL_SYS_QuitTLSData(void)
+{
+    if (generic_local_storage) {
+        SDL_Generic_QuitTLSData();
+        generic_local_storage = SDL_FALSE;
+    } else {
+        if (thread_local_storage != INVALID_PTHREAD_KEY) {
+            pthread_key_delete(thread_local_storage);
+            thread_local_storage = INVALID_PTHREAD_KEY;
+        }
+    }
+}

+ 19 - 6
src/thread/stdcpp/SDL_systhread.cpp

@@ -143,14 +143,27 @@ SDL_SYS_DetachThread(SDL_Thread *thread)
     }
 }
 
-extern "C" SDL_TLSData *
-SDL_SYS_GetTLSData(void)
+static thread_local SDL_TLSData *thread_local_storage;
+
+extern "C"
+void SDL_SYS_InitTLSData(void)
 {
-    return SDL_Generic_GetTLSData();
 }
 
-extern "C" int
-SDL_SYS_SetTLSData(SDL_TLSData *data)
+extern "C"
+SDL_TLSData * SDL_SYS_GetTLSData(void)
+{
+    return thread_local_storage;
+}
+
+extern "C"
+int SDL_SYS_SetTLSData(SDL_TLSData *data)
+{
+    thread_local_storage = data;
+    return 0;
+}
+
+extern "C"
+void SDL_SYS_QuitTLSData(void)
 {
-    return SDL_Generic_SetTLSData(data);
 }

+ 29 - 15
src/thread/windows/SDL_systls.c

@@ -42,27 +42,27 @@
 static DWORD thread_local_storage = TLS_OUT_OF_INDEXES;
 static SDL_bool generic_local_storage = SDL_FALSE;
 
-SDL_TLSData *SDL_SYS_GetTLSData(void)
+void SDL_SYS_InitTLSData(void)
 {
     if (thread_local_storage == TLS_OUT_OF_INDEXES && !generic_local_storage) {
-        static SDL_SpinLock lock;
-        SDL_LockSpinlock(&lock);
-        if (thread_local_storage == TLS_OUT_OF_INDEXES && !generic_local_storage) {
-            DWORD storage = TlsAlloc();
-            if (storage != TLS_OUT_OF_INDEXES) {
-                SDL_MemoryBarrierRelease();
-                thread_local_storage = storage;
-            } else {
-                generic_local_storage = SDL_TRUE;
-            }
+        thread_local_storage = TlsAlloc();
+        if (thread_local_storage == TLS_OUT_OF_INDEXES) {
+            SDL_Generic_InitTLSData();
+            generic_local_storage = SDL_TRUE;
         }
-        SDL_UnlockSpinlock(&lock);
     }
+}
+
+SDL_TLSData *SDL_SYS_GetTLSData(void)
+{
     if (generic_local_storage) {
         return SDL_Generic_GetTLSData();
     }
-    SDL_MemoryBarrierAcquire();
-    return (SDL_TLSData *)TlsGetValue(thread_local_storage);
+
+    if (thread_local_storage != TLS_OUT_OF_INDEXES) {
+        return (SDL_TLSData *)TlsGetValue(thread_local_storage);
+    }
+    return NULL;
 }
 
 int SDL_SYS_SetTLSData(SDL_TLSData *data)
@@ -70,10 +70,24 @@ int SDL_SYS_SetTLSData(SDL_TLSData *data)
     if (generic_local_storage) {
         return SDL_Generic_SetTLSData(data);
     }
+
     if (!TlsSetValue(thread_local_storage, data)) {
-        return SDL_SetError("TlsSetValue() failed");
+        return WIN_SetError("TlsSetValue()");
     }
     return 0;
 }
 
+void SDL_SYS_QuitTLSData(void)
+{
+    if (generic_local_storage) {
+        SDL_Generic_QuitTLSData();
+        generic_local_storage = SDL_FALSE;
+    } else {
+        if (thread_local_storage != TLS_OUT_OF_INDEXES) {
+            TlsFree(thread_local_storage);
+            thread_local_storage = TLS_OUT_OF_INDEXES;
+        }
+    }
+}
+
 #endif /* SDL_THREAD_WINDOWS */