题目描述:
给出一个长度为 $ n $ 链表的头结点head
,按照升序排列,并且返回排序之后的链表。
数据范围:
$ 0\le n \le 5 \times 10^4 $
题解:
首先观察数据范围,发现只能使用 $ O(n\log(n)) $ 的算法解决。可供选择的有快排,堆排,归并。链表比较适合使用归并排序。
与一般归并不同的是,需要合并成一个新的链表,注意需要返回头指针。一般的数组排序是需要将合并后的数据拷贝回原来的数组的,因此只需要记录每段的起点终点即可。因为链表的合并,最后返回了一个新的头指针,并且没有拷回。需要在分治与合并的都返回头指针。
找中点可以使用快慢指针。合并时剩余的部分可以直接接到链表的尾部,不用一个一个遍历。
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
| class Solution { public: ListNode *merge_sort(ListNode *head, ListNode *tail) { if (head->next == tail) { head->next = nullptr; return head; } ListNode *fast = head, *slow = head; while (fast != tail) { fast = fast->next; if (fast != tail) fast = fast->next; slow = slow->next; } ListNode *left = merge_sort(head, slow); ListNode *right = merge_sort(slow, tail); return mix_merge(left, right); }
ListNode *mix_merge(ListNode *left, ListNode *right) { ListNode *head = new ListNode; ListNode *tail = head; ListNode *l = left, *r = right; while (l != nullptr && r != nullptr) { if (l->val < r->val) { tail->next = l; tail = l; l = l->next; } else { tail->next = r; tail = r; r = r->next; } } if (l != nullptr) { tail->next = l; } if (r != nullptr) { tail->next = r; } return head->next; } ListNode *sortList(ListNode *head) { if (head == nullptr || head->next == nullptr) return head; return merge_sort(head, nullptr); } };
|
循环版本:空间复杂度 $ O(1) $
需要先统计长度,然后枚举子区间长度,从 $ 1 $ 开始枚举,每次翻倍。
每次找到两个长度为 length
的链表,把结尾处理一下,然后合并,接到新链表的结尾。
然后继续,一直到无法找到。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
| class Solution { public: const static int maxn = 1e5 + 10; const static int maxm = 1e5 + 10; const static int INF = 0x3f3f3f3f; ListNode *sortList(ListNode *head) { if (head == nullptr || head->next == nullptr) return head; ListNode *vir_head = new ListNode(0, head); int list_length = 0; ListNode *cur = vir_head->next; while (cur) { cur = cur->next; list_length++; } for (int length = 1; length < list_length; length <<= 1) { ListNode *cur = vir_head->next; ListNode *tail = vir_head; while (cur) { ListNode *head1 = cur; for (int i = 1; i < length && cur != nullptr; ++i) { cur = cur->next; } ListNode *head2 = nullptr; if (cur) { head2 = cur->next; cur->next = nullptr; cur = head2; } for (int i = 1; i < length && cur != nullptr; ++i) { cur = cur->next; } ListNode *nex = nullptr; if (cur) { nex = cur->next; cur->next = nullptr; } ListNode *m_head = merge(head1, head2); tail->next = m_head; while (tail->next) { tail = tail->next; } cur = nex; } } return vir_head->next; }
ListNode *merge(ListNode *head1, ListNode *head2) { ListNode *head = new ListNode(0, nullptr); ListNode *tail = head; while (head1 != nullptr && head2 != nullptr) { if (head1->val < head2->val) { tail->next = head1; tail = head1; head1 = head1->next; } else { tail->next = head2; tail = head2; head2 = head2->next; } } if (head1) { tail->next = head1; } if (head2) { tail->next = head2; } return head->next; } };
|