要求如下:
gotype MyConcurrentMap struct{
// ... implement it
}
func (m *MyConcurrentMap) Put(key, val int){
// ... implement it
}
func (m *MyConcurrentMap) Get(key int, maxWaitingDuration time.Duration) (int ,error){
// ... implement it
}
gopackage myconcurrentmap
import (
"context"
"sync"
"time"
)
// 一道考题,实现一个MAP,要求如下
// 1、面向高并发
// 2、只存在插入和查询的操作O(1)的时间复杂度
// 3、查询时,若 key 存在,直接返回 val; 若 key 不存在,阻塞到 key val 对被放入后,获取 val 返回,等待指定时长仍未放入,返回超时错误
// 4、写出真实代码,不能有死锁或 painc 场景
type MyConcurrentMap struct {
// 互斥锁
sync.Mutex
// map 来保证时间复杂度满足要求
mp map[int]int
// keyToChan 用于存放 key 的 channel
keyToChan map[int]chan struct{}
}
// NewMyConcurrentMap 构造函数
func NewMyConcurrentMap() *MyConcurrentMap {
return &MyConcurrentMap{
mp: make(map[int]int),
keyToChan: make(map[int]chan struct{}),
}
}
func (m *MyConcurrentMap) Put(key, val int) {
m.Lock()
defer m.Unlock()
m.mp[key] = val
ch, ok := m.keyToChan[key]
if !ok {
return
}
// 如果 m.keyToChan[key] 中存在 ch,则说明有其他 goroutine 正在等待key数据的写入
// 这里使用 close 方法来关闭 channel,能够使得所有正在读 ch的 goroutine 都能被唤醒
// 并且要注意多次关闭同一个 channel会引发 panic,所以这里需要使用 select 多路监听去判断 ch
// 如果尝试从一个已关闭的 channel 读取的话,会直接返回对应类型数据的空值
select {
case <-ch:
return
default:
close(ch)
}
}
func (m *MyConcurrentMap) Get(key int, maxWaitingDuration time.Duration) (int, error) {
m.Lock()
val, ok := m.mp[key]
if ok {
m.Unlock()
return val, nil
}
// 如果 key 不存在,则往keyToChan这个 map 中添加一个 channel 用于 Put 函数通知我们对应的 key已经放入到 m.mp 中了
// 存在同一个 key 多次 Get 情况,所以要共享同一个 channel
ch, ok := m.keyToChan[key]
if !ok {
ch = make(chan struct{})
m.keyToChan[key] = ch
}
tCtx, cancelFunc := context.WithTimeout(context.Background(), maxWaitingDuration)
defer cancelFunc()
m.Unlock()
select {
case <-tCtx.Done():
return -1, tCtx.Err()
case <-ch:
}
// 这里还需加锁,此时如果还有其他 goroutine 正在写这个 key val 对的话,有可能会并发读写panic
m.Lock()
val = m.mp[key]
m.Unlock()
return val, nil
}
gopackage myconcurrentmap
import (
"context"
"testing"
"time"
)
func TestMyConcurrentMap_PutAndGet(t *testing.T) {
m := NewMyConcurrentMap()
// 测试Put和Get
m.Put(1, 100)
val, err := m.Get(1, time.Second)
if err != nil {
t.Errorf("Get returned an error: %v", err)
}
if val != 100 {
t.Errorf("Get returned incorrect value, expected 100 but got %d", val)
}
}
func TestMyConcurrentMap_GetWithTimeout(t *testing.T) {
m := NewMyConcurrentMap()
// 测试Get超时
_, err := m.Get(2, time.Second)
if err == nil || err != context.DeadlineExceeded {
t.Errorf("Get did not return expected error or did not timeout as expected")
}
}
func TestMyConcurrentMap_ConcurrentPutAndGet(t *testing.T) {
m := NewMyConcurrentMap()
// 测试并发Put和Get
go func() {
m.Put(3, 300)
}()
go func() {
time.Sleep(100 * time.Millisecond) // 确保Put操作已经开始
val, err := m.Get(3, time.Second)
if err != nil {
t.Errorf("Get returned an error: %v", err)
}
if val != 300 {
t.Errorf("Get returned incorrect value, expected 300 but got %d", val)
}
}()
time.Sleep(200 * time.Millisecond) // 确保Get操作已经完成
}
func TestMyConcurrentMap_HighConcurrency(t *testing.T) {
m := NewMyConcurrentMap()
// 并发执行Put操作
go func() {
for i := 0; i < 1000; i++ {
m.Put(i, i*10)
}
}()
// 并发执行Get操作
done := make(chan struct{})
for i := 0; i < 1000; i++ {
go func(key, expectedVal int) {
val, err := m.Get(key, time.Second)
if err != nil {
t.Errorf("Get returned an error: %v", err)
}
if val != expectedVal {
t.Errorf("Get returned incorrect value, expected %d but got %d", expectedVal, val)
}
done <- struct{}{}
}(i, i*10)
}
// 等待所有Get操作完成
for i := 0; i < 1000; i++ {
<-done
}
}
本文作者:relakkes
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!